From 991b540c8bfac259987b96f46dea6122820403a1 Mon Sep 17 00:00:00 2001 From: mpl Date: Fri, 25 Apr 2014 19:47:35 +0200 Subject: [PATCH] importer/twitter: refactor credentials isolation -store in the importer account the temporary credentials needed during setup. -store in the importer account the access token, and the access token secret, and guard them. They are needed at the end of the setup (to get the user's info) and all along a run. -not share the same oauth.Client with the whole package. Create a new one everytime it's needed in the setup. Create a new one at the beginning of a run too, and it's ok to reuse it all along, since it's used read-only. Change-Id: I4823b35abf8a5170f63ca502026f6c8f98e2c1e7 --- pkg/importer/twitter/twitter.go | 160 +++++++++++++++++---------- pkg/importer/twitter/twitter_test.go | 8 +- 2 files changed, 100 insertions(+), 68 deletions(-) diff --git a/pkg/importer/twitter/twitter.go b/pkg/importer/twitter/twitter.go index 6817dbf19..95bb47710 100644 --- a/pkg/importer/twitter/twitter.go +++ b/pkg/importer/twitter/twitter.go @@ -25,7 +25,6 @@ import ( "net/url" "strconv" "strings" - "sync" "time" "camlistore.org/pkg/context" @@ -43,11 +42,15 @@ const ( userInfoAPIPath = "account/verify_credentials.json" // Permanode attributes on account node: - acctAttrUserID = "twitterUserID" - acctAttrScreenName = "twitterScreenName" - acctAttrUserFirst = "twitterFirstName" - acctAttrUserLast = "twitterLastName" - acctAttrAccessToken = "oauthAccessToken" + acctAttrUserID = "twitterUserID" + acctAttrScreenName = "twitterScreenName" + acctAttrUserFirst = "twitterFirstName" + acctAttrUserLast = "twitterLastName" + // TODO(mpl): refactor these 4 below into an oauth package when doing flickr. + acctAttrTempToken = "oauthTempToken" + acctAttrTempSecret = "oauthTempSecret" + acctAttrAccessToken = "oauthAccessToken" + acctAttrAccessTokenSecret = "oauthAccessTokenSecret" tweetRequestLimit = 200 // max number of tweets we can get in a user_timeline request ) @@ -58,32 +61,7 @@ func init() { var _ importer.ImporterSetupHTMLer = (*imp)(nil) -var ( - oauthClient = &oauth.Client{ - TemporaryCredentialRequestURI: temporaryCredentialRequestURL, - ResourceOwnerAuthorizationURI: resourceOwnerAuthorizationURL, - TokenRequestURI: tokenRequestURL, - } -) - -type imp struct { - // cred are the various credentials passed around during OAuth. First the temporary - // ones, then the access token and secret. - mu sync.Mutex // guards credsVal - credsVal *oauth.Credentials -} - -func (im *imp) creds() *oauth.Credentials { - im.mu.Lock() - defer im.mu.Unlock() - return im.credsVal -} - -func (im *imp) setCreds(v *oauth.Credentials) { - im.mu.Lock() - defer im.mu.Unlock() - im.credsVal = v -} +type imp struct{} func (im *imp) NeedsAPIKey() bool { return true } @@ -129,13 +107,38 @@ func (im *imp) AccountSetupHTML(host *importer.Host) string { // A run is our state for a given run of the importer. type run struct { *importer.RunContext - im *imp + im *imp + oauthClient *oauth.Client // No need to guard, used read-only. + accessCreds *oauth.Credentials // No need to guard, used read-only. } func (im *imp) Run(ctx *importer.RunContext) error { + clientId, secret, err := ctx.Credentials() + if err != nil { + return fmt.Errorf("no API credentials: %v", err) + } + accountNode := ctx.AccountNode() + accessToken := accountNode.Attr(acctAttrAccessToken) + accessSecret := accountNode.Attr(acctAttrAccessTokenSecret) + if accessToken == "" || accessSecret == "" { + return errors.New("access credentials not found") + } r := &run{ RunContext: ctx, im: im, + oauthClient: &oauth.Client{ + TemporaryCredentialRequestURI: temporaryCredentialRequestURL, + ResourceOwnerAuthorizationURI: resourceOwnerAuthorizationURL, + TokenRequestURI: tokenRequestURL, + Credentials: oauth.Credentials{ + Token: clientId, + Secret: secret, + }, + }, + accessCreds: &oauth.Credentials{ + Token: accessToken, + Secret: accessSecret, + }, } userID := ctx.AccountNode().Attr(acctAttrUserID) if userID == "" { @@ -165,7 +168,7 @@ func (r *run) importTweets(userID string) error { } var resp []*tweetItem - if err := r.im.doAPI(r.Context, &resp, "statuses/user_timeline.json", + if err := doAPI(r.Context, r.oauthClient, r.accessCreds, &resp, "statuses/user_timeline.json", "user_id", userID, "count", strconv.Itoa(tweetRequestLimit), "max_id", maxId); err != nil { @@ -243,9 +246,9 @@ type userInfo struct { Name string `json:"name,omitempty"` } -func (im *imp) getUserInfo(ctx *context.Context) (userInfo, error) { +func getUserInfo(ctx *context.Context, oauthClient *oauth.Client, creds *oauth.Credentials) (userInfo, error) { var ui userInfo - if err := im.doAPI(ctx, &ui, userInfoAPIPath); err != nil { + if err := doAPI(ctx, oauthClient, creds, &ui, userInfoAPIPath); err != nil { return ui, err } if ui.ID == "" { @@ -254,13 +257,16 @@ func (im *imp) getUserInfo(ctx *context.Context) (userInfo, error) { return ui, nil } -func (im *imp) doAPI(ctx *context.Context, result interface{}, apiPath string, keyval ...string) error { +func doAPI(ctx *context.Context, oauthClient *oauth.Client, creds *oauth.Credentials, result interface{}, apiPath string, keyval ...string) error { if len(keyval)%2 == 1 { panic("Incorrect number of keyval arguments. must be even.") } - if im.creds() == nil { - return fmt.Errorf("No authentication creds") + if creds == nil { + return errors.New("No authentication creds") + } + if oauthClient == nil { + return errors.New("No authentication client") } form := url.Values{} @@ -271,7 +277,7 @@ func (im *imp) doAPI(ctx *context.Context, result interface{}, apiPath string, k } fullURL := apiURL + apiPath - res, err := im.doGet(ctx, fullURL, form) + res, err := doGet(ctx, oauthClient, creds, fullURL, form) if err != nil { return err } @@ -282,11 +288,13 @@ func (im *imp) doAPI(ctx *context.Context, result interface{}, apiPath string, k return nil } -func (im *imp) doGet(ctx *context.Context, url string, form url.Values) (*http.Response, error) { - creds := im.creds() +func doGet(ctx *context.Context, oauthClient *oauth.Client, creds *oauth.Credentials, url string, form url.Values) (*http.Response, error) { if creds == nil { return nil, errors.New("No OAuth credentials. Not logged in?") } + if creds == nil { + return nil, errors.New("No OAuth client.") + } res, err := oauthClient.Get(ctx.HTTPClient(), creds, url, form) if err != nil { return nil, fmt.Errorf("Error fetching %s: %v", url, err) @@ -297,31 +305,43 @@ func (im *imp) doGet(ctx *context.Context, url string, form url.Values) (*http.R return res, nil } -func auth(ctx *importer.SetupContext) (*oauth.Credentials, error) { +func newOauthClient(ctx *importer.SetupContext) (*oauth.Client, error) { clientId, secret, err := ctx.Credentials() if err != nil { return nil, err } - return &oauth.Credentials{ - Token: clientId, - Secret: secret, + return &oauth.Client{ + TemporaryCredentialRequestURI: temporaryCredentialRequestURL, + ResourceOwnerAuthorizationURI: resourceOwnerAuthorizationURL, + TokenRequestURI: tokenRequestURL, + Credentials: oauth.Credentials{ + Token: clientId, + Secret: secret, + }, }, nil } func (im *imp) ServeSetup(w http.ResponseWriter, r *http.Request, ctx *importer.SetupContext) error { - cred, err := auth(ctx) + oauthClient, err := newOauthClient(ctx) if err != nil { - err = fmt.Errorf("Error getting API credentials: %v", err) + err = fmt.Errorf("error getting OAuth client: %v", err) httputil.ServeError(w, r, err) return err } - oauthClient.Credentials = *cred tempCred, err := oauthClient.RequestTemporaryCredentials(ctx.HTTPClient(), ctx.CallbackURL(), nil) if err != nil { err = fmt.Errorf("Error getting temp cred: %v", err) httputil.ServeError(w, r, err) + return err + } + if err := ctx.AccountNode.SetAttrs( + acctAttrTempToken, tempCred.Token, + acctAttrTempSecret, tempCred.Secret, + ); err != nil { + err = fmt.Errorf("Error saving temp creds: %v", err) + httputil.ServeError(w, r, err) + return err } - im.setCreds(tempCred) authURL := oauthClient.AuthorizationURL(tempCred, nil) http.Redirect(w, r, authURL, 302) @@ -329,19 +349,32 @@ func (im *imp) ServeSetup(w http.ResponseWriter, r *http.Request, ctx *importer. } func (im *imp) ServeCallback(w http.ResponseWriter, r *http.Request, ctx *importer.SetupContext) { - creds := im.creds() - if creds == nil { - log.Printf("twitter: nil creds in callback") - httputil.BadRequestError(w, "nil creds in callback") + tempToken := ctx.AccountNode.Attr(acctAttrTempToken) + tempSecret := ctx.AccountNode.Attr(acctAttrTempSecret) + if tempToken == "" || tempSecret == "" { + log.Printf("twitter: no temp creds in callback") + httputil.BadRequestError(w, "no temp creds in callback") return } - if creds.Token != r.FormValue("oauth_token") { - log.Printf("unexpected oauth_token: got %v, want %v", r.FormValue("oauth_token"), creds.Token) + if tempToken != r.FormValue("oauth_token") { + log.Printf("unexpected oauth_token: got %v, want %v", r.FormValue("oauth_token"), tempToken) httputil.BadRequestError(w, "unexpected oauth_token") return } - - tokenCred, vals, err := oauthClient.RequestToken(ctx.Context.HTTPClient(), creds, r.FormValue("oauth_verifier")) + oauthClient, err := newOauthClient(ctx) + if err != nil { + err = fmt.Errorf("error getting OAuth client: %v", err) + httputil.ServeError(w, r, err) + return + } + tokenCred, vals, err := oauthClient.RequestToken( + ctx.Context.HTTPClient(), + &oauth.Credentials{ + Token: tempToken, + Secret: tempSecret, + }, + r.FormValue("oauth_verifier"), + ) if err != nil { httputil.ServeError(w, r, fmt.Errorf("Error getting request token: %v ", err)) return @@ -351,9 +384,15 @@ func (im *imp) ServeCallback(w http.ResponseWriter, r *http.Request, ctx *import httputil.ServeError(w, r, fmt.Errorf("Couldn't get user id: %v", err)) return } - im.setCreds(tokenCred) + if err := ctx.AccountNode.SetAttrs( + acctAttrAccessToken, tokenCred.Token, + acctAttrAccessTokenSecret, tokenCred.Secret, + ); err != nil { + httputil.ServeError(w, r, fmt.Errorf("Error setting token attributes: %v", err)) + return + } - u, err := im.getUserInfo(ctx.Context) + u, err := getUserInfo(ctx.Context, oauthClient, tokenCred) if err != nil { httputil.ServeError(w, r, fmt.Errorf("Couldn't get user info: %v", err)) return @@ -370,7 +409,6 @@ func (im *imp) ServeCallback(w http.ResponseWriter, r *http.Request, ctx *import acctAttrUserFirst, firstName, acctAttrUserLast, lastName, acctAttrScreenName, u.ScreenName, - acctAttrAccessToken, tokenCred.Token, ); err != nil { httputil.ServeError(w, r, fmt.Errorf("Error setting attribute: %v", err)) return diff --git a/pkg/importer/twitter/twitter_test.go b/pkg/importer/twitter/twitter_test.go index 29f00b5af..4748caf11 100644 --- a/pkg/importer/twitter/twitter_test.go +++ b/pkg/importer/twitter/twitter_test.go @@ -30,19 +30,13 @@ import ( ) func TestGetUserID(t *testing.T) { - im := &imp{ - credsVal: &oauth.Credentials{ - Token: "foo", - Secret: "bar", - }, - } ctx := context.New() ctx.SetHTTPClient(&http.Client{ Transport: newFakeTransport(map[string]func() *http.Response{ apiURL + userInfoAPIPath: fileResponder(filepath.FromSlash("testdata/verify_credentials-res.json")), }), }) - inf, err := im.getUserInfo(ctx) + inf, err := getUserInfo(ctx, &oauth.Client{}, &oauth.Credentials{}) if err != nil { t.Fatal(err) }