Re-auth when access token expires
This commit is contained in:
parent
8066bcbad4
commit
480af1016a
|
@ -166,8 +166,15 @@ func (m *Manager) Auth(username, password string) (*protonmail.Client, openpgp.E
|
|||
return nil, nil, err
|
||||
}
|
||||
|
||||
// authenticate updates cachedAuth with the new refresh token
|
||||
c := m.newClient()
|
||||
c.ReAuth = func() error {
|
||||
if _, err := authenticate(c, &cachedAuth); err != nil {
|
||||
return err
|
||||
}
|
||||
return EncryptAndSave(&cachedAuth, username, &secretKey)
|
||||
}
|
||||
|
||||
// authenticate updates cachedAuth with the new refresh token
|
||||
privateKeys, err := authenticate(c, &cachedAuth)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
|
@ -48,12 +49,20 @@ type Client struct {
|
|||
ClientSecret string
|
||||
|
||||
HTTPClient *http.Client
|
||||
ReAuth func() error
|
||||
|
||||
uid string
|
||||
accessToken string
|
||||
keyRing openpgp.EntityList
|
||||
}
|
||||
|
||||
func (c *Client) setRequestAuthorization(req *http.Request) {
|
||||
if c.uid != "" && c.accessToken != "" {
|
||||
req.Header.Set("X-Pm-Uid", c.uid)
|
||||
req.Header.Set("Authorization", "Bearer "+c.accessToken)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) newRequest(method, path string, body io.Reader) (*http.Request, error) {
|
||||
req, err := http.NewRequest(method, c.RootURL+path, body)
|
||||
if err != nil {
|
||||
|
@ -62,25 +71,26 @@ func (c *Client) newRequest(method, path string, body io.Reader) (*http.Request,
|
|||
|
||||
req.Header.Set("X-Pm-Appversion", c.AppVersion)
|
||||
req.Header.Set(headerAPIVersion, strconv.Itoa(Version))
|
||||
|
||||
if c.uid != "" && c.accessToken != "" {
|
||||
req.Header.Set("X-Pm-Uid", c.uid)
|
||||
req.Header.Set("Authorization", "Bearer "+c.accessToken)
|
||||
}
|
||||
|
||||
c.setRequestAuthorization(req)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (c *Client) newJSONRequest(method, path string, body interface{}) (*http.Request, error) {
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(body); err != nil {
|
||||
var buf bytes.Buffer
|
||||
if err := json.NewEncoder(&buf).Encode(body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := c.newRequest(method, path, &b)
|
||||
b := buf.Bytes()
|
||||
|
||||
req, err := c.newRequest(method, path, bytes.NewReader(b))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.GetBody = func() (io.ReadCloser, error) {
|
||||
return ioutil.NopCloser(bytes.NewReader(b)), nil
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
|
@ -89,7 +99,33 @@ func (c *Client) do(req *http.Request) (*http.Response, error) {
|
|||
if httpClient == nil {
|
||||
httpClient = http.DefaultClient
|
||||
}
|
||||
return httpClient.Do(req)
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Check if access token has expired
|
||||
_, hasAuth := req.Header["Authorization"]
|
||||
canRetry := req.Body == nil || req.GetBody != nil
|
||||
if resp.StatusCode == http.StatusUnauthorized && hasAuth && c.ReAuth != nil && canRetry {
|
||||
resp.Body.Close()
|
||||
c.accessToken = ""
|
||||
if err := c.ReAuth(); err != nil {
|
||||
return resp, err
|
||||
}
|
||||
c.setRequestAuthorization(req) // Access token has changed
|
||||
if req.Body != nil {
|
||||
body, err := req.GetBody()
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
req.Body = body
|
||||
}
|
||||
return c.do(req)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *Client) doJSON(req *http.Request, respData interface{}) error {
|
||||
|
|
Loading…
Reference in New Issue