diff --git a/reddit/app.go b/reddit/app.go index 787f255..9b30aae 100644 --- a/reddit/app.go +++ b/reddit/app.go @@ -1,6 +1,10 @@ package reddit -import "fmt" +import ( + "fmt" + + "golang.org/x/oauth2" +) var ( errMissingOauthCredentials = fmt.Errorf("missing oauth credentials") @@ -24,10 +28,13 @@ type App struct { // tokenURL is the url of the token request location for OAuth2. tokenURL string + + // If token is specified, username/password authentication is skipped + Token *oauth2.Token } func (a App) unauthenticated() bool { - return a.ID == "" || a.Secret == "" + return a.Token == nil && (a.ID == "" || a.Secret == "") } func (a App) validateAuth() error { @@ -35,6 +42,10 @@ func (a App) validateAuth() error { return errMissingOauthCredentials } + if a.Token != nil { + return nil + } + if a.Password != "" && a.Username == "" { return errMissingUsername } diff --git a/reddit/app_test.go b/reddit/app_test.go index 380da69..2d9f901 100644 --- a/reddit/app_test.go +++ b/reddit/app_test.go @@ -2,6 +2,8 @@ package reddit import ( "testing" + + "golang.org/x/oauth2" ) func TestAppUnauthenticated(t *testing.T) { @@ -9,12 +11,14 @@ func TestAppUnauthenticated(t *testing.T) { input App output bool }{ - {App{"y", "", "", "", ""}, true}, - {App{"", "y", "", "", ""}, true}, - {App{"y", "y", "", "", ""}, false}, - {App{"y", "y", "y", "", ""}, false}, - {App{"y", "y", "", "y", ""}, false}, - {App{"y", "y", "y", "y", ""}, false}, + {App{"y", "", "", "", "", nil}, true}, + {App{"", "y", "", "", "", nil}, true}, + {App{"y", "", "", "", "", &oauth2.Token{}}, false}, + {App{"", "y", "", "", "", &oauth2.Token{}}, false}, + {App{"y", "y", "", "", "", nil}, false}, + {App{"y", "y", "y", "", "", nil}, false}, + {App{"y", "y", "", "y", "", nil}, false}, + {App{"y", "y", "y", "y", "", nil}, false}, } { if actual := test.input.unauthenticated(); actual != test.output { t.Errorf("wrong on %d; wanted %v", i, test.output) @@ -27,13 +31,18 @@ func TestAppValidateAuth(t *testing.T) { input App output error }{ - {App{"", "", "", "", ""}, errMissingOauthCredentials}, - {App{"y", "", "", "", ""}, errMissingOauthCredentials}, - {App{"", "y", "", "", ""}, errMissingOauthCredentials}, - {App{"y", "y", "y", "", ""}, errMissingPassword}, - {App{"y", "y", "", "y", ""}, errMissingUsername}, - {App{"y", "y", "", "", ""}, nil}, - {App{"y", "y", "y", "y", ""}, nil}, + {App{"", "", "", "", "", nil}, errMissingOauthCredentials}, + {App{"y", "", "", "", "", nil}, errMissingOauthCredentials}, + {App{"", "y", "", "", "", nil}, errMissingOauthCredentials}, + {App{"y", "y", "y", "", "", nil}, errMissingPassword}, + {App{"y", "y", "", "y", "", nil}, errMissingUsername}, + {App{"", "", "", "", "", &oauth2.Token{}}, nil}, + {App{"y", "", "", "", "", &oauth2.Token{}}, nil}, + {App{"", "y", "", "", "", &oauth2.Token{}}, nil}, + {App{"y", "y", "y", "", "", &oauth2.Token{}}, nil}, + {App{"y", "y", "", "y", "", &oauth2.Token{}}, nil}, + {App{"y", "y", "", "", "", nil}, nil}, + {App{"y", "y", "y", "y", "", nil}, nil}, } { if actual := test.input.validateAuth(); actual != test.output { t.Errorf("wrong on %d; wanted %v", i, test.output) diff --git a/reddit/appclient.go b/reddit/appclient.go index 0478363..af2b967 100644 --- a/reddit/appclient.go +++ b/reddit/appclient.go @@ -29,7 +29,7 @@ func (a *appClient) Do(req *http.Request) ([]byte, error) { func (a *appClient) authorize() error { ctx := context.WithValue(oauth2.NoContext, oauth2.HTTPClient, a.cli) - if a.cfg.app.Username == "" || a.cfg.app.Password == "" { + if a.cfg.app.unauthenticated() { a.baseClient.cli = a.clientCredentialsClient(ctx) return nil } @@ -41,13 +41,22 @@ func (a *appClient) authorize() error { Scopes: oauthScopes, } - token, err := cfg.PasswordCredentialsToken( - ctx, - a.cfg.app.Username, - a.cfg.app.Password, - ) + var token *oauth2.Token + var err error + + if a.cfg.app.Token != nil { + token = a.cfg.app.Token + err = nil + } else { + token, err = cfg.PasswordCredentialsToken( + ctx, + a.cfg.app.Username, + a.cfg.app.Password, + ) + } a.baseClient.cli = cfg.Client(ctx, token) + return err }