diff --git a/auth/auth.go b/auth/auth.go index 315323c..da526bc 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "encoding/json" "errors" + "fmt" "io" "os" @@ -97,15 +98,29 @@ func EncryptAndSave(auth *CachedAuth, username string, secretKey *[32]byte) erro return saveAuths(auths) } -func authenticate(c *protonmail.Client, CachedAuth *CachedAuth) (openpgp.EntityList, error) { - auth, err := c.AuthRefresh(&CachedAuth.Auth) - if err != nil { - // TODO: handle expired token, re-authenticate +func authenticate(c *protonmail.Client, cachedAuth *CachedAuth, username string) (openpgp.EntityList, error) { + auth, err := c.AuthRefresh(&cachedAuth.Auth) + if apiErr, ok := err.(*protonmail.ApiError); ok && apiErr.Code == 10013 { + // Invalid refresh token, re-authenticate + authInfo, err := c.AuthInfo(username) + if err != nil { + return nil, fmt.Errorf("cannot re-authenticate: failed to get auth info: %v", err) + } + + if authInfo.TwoFactor == 1 { + return nil, fmt.Errorf("cannot re-authenticate: two factor authentication enabled, please login manually") + } + + auth, err = c.Auth(username, cachedAuth.LoginPassword, "", authInfo) + if err != nil { + return nil, fmt.Errorf("cannot re-authenticate: %v", err) + } + } else if err != nil { return nil, err } - CachedAuth.Auth = *auth + cachedAuth.Auth = *auth - return c.Unlock(auth, CachedAuth.MailboxPassword) + return c.Unlock(auth, cachedAuth.MailboxPassword) } func GeneratePassword() (secretKey *[32]byte, password string, err error) { @@ -168,14 +183,14 @@ func (m *Manager) Auth(username, password string) (*protonmail.Client, openpgp.E c := m.newClient() c.ReAuth = func() error { - if _, err := authenticate(c, &cachedAuth); err != nil { + if _, err := authenticate(c, &cachedAuth, username); err != nil { return err } return EncryptAndSave(&cachedAuth, username, &secretKey) } // authenticate updates cachedAuth with the new refresh token - privateKeys, err := authenticate(c, &cachedAuth) + privateKeys, err := authenticate(c, &cachedAuth, username) if err != nil { return nil, nil, err }