diff --git a/auth/auth.go b/auth/auth.go index 5132051..315323c 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -166,8 +166,15 @@ func (m *Manager) Auth(username, password string) (*protonmail.Client, openpgp.E return nil, nil, err } - // authenticate updates cachedAuth with the new refresh token c := m.newClient() + c.ReAuth = func() error { + if _, err := authenticate(c, &cachedAuth); err != nil { + return err + } + return EncryptAndSave(&cachedAuth, username, &secretKey) + } + + // authenticate updates cachedAuth with the new refresh token privateKeys, err := authenticate(c, &cachedAuth) if err != nil { return nil, nil, err diff --git a/protonmail/protonmail.go b/protonmail/protonmail.go index 58dad80..c61a769 100644 --- a/protonmail/protonmail.go +++ b/protonmail/protonmail.go @@ -5,6 +5,7 @@ import ( "bytes" "encoding/json" "io" + "io/ioutil" "net/http" "strconv" @@ -48,12 +49,20 @@ type Client struct { ClientSecret string HTTPClient *http.Client + ReAuth func() error uid string accessToken string keyRing openpgp.EntityList } +func (c *Client) setRequestAuthorization(req *http.Request) { + if c.uid != "" && c.accessToken != "" { + req.Header.Set("X-Pm-Uid", c.uid) + req.Header.Set("Authorization", "Bearer "+c.accessToken) + } +} + func (c *Client) newRequest(method, path string, body io.Reader) (*http.Request, error) { req, err := http.NewRequest(method, c.RootURL+path, body) if err != nil { @@ -62,25 +71,26 @@ func (c *Client) newRequest(method, path string, body io.Reader) (*http.Request, req.Header.Set("X-Pm-Appversion", c.AppVersion) req.Header.Set(headerAPIVersion, strconv.Itoa(Version)) - - if c.uid != "" && c.accessToken != "" { - req.Header.Set("X-Pm-Uid", c.uid) - req.Header.Set("Authorization", "Bearer "+c.accessToken) - } - + c.setRequestAuthorization(req) return req, nil } func (c *Client) newJSONRequest(method, path string, body interface{}) (*http.Request, error) { - var b bytes.Buffer - if err := json.NewEncoder(&b).Encode(body); err != nil { + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(body); err != nil { return nil, err } - req, err := c.newRequest(method, path, &b) + b := buf.Bytes() + + req, err := c.newRequest(method, path, bytes.NewReader(b)) if err != nil { return nil, err } + req.Header.Set("Content-Type", "application/json") + req.GetBody = func() (io.ReadCloser, error) { + return ioutil.NopCloser(bytes.NewReader(b)), nil + } return req, nil } @@ -89,7 +99,33 @@ func (c *Client) do(req *http.Request) (*http.Response, error) { if httpClient == nil { httpClient = http.DefaultClient } - return httpClient.Do(req) + + resp, err := httpClient.Do(req) + if err != nil { + return resp, err + } + + // Check if access token has expired + _, hasAuth := req.Header["Authorization"] + canRetry := req.Body == nil || req.GetBody != nil + if resp.StatusCode == http.StatusUnauthorized && hasAuth && c.ReAuth != nil && canRetry { + resp.Body.Close() + c.accessToken = "" + if err := c.ReAuth(); err != nil { + return resp, err + } + c.setRequestAuthorization(req) // Access token has changed + if req.Body != nil { + body, err := req.GetBody() + if err != nil { + return resp, err + } + req.Body = body + } + return c.do(req) + } + + return resp, nil } func (c *Client) doJSON(req *http.Request, respData interface{}) error {