diff --git a/protonmail/auth.go b/protonmail/auth.go index 1e57210..6d804da 100644 --- a/protonmail/auth.go +++ b/protonmail/auth.go @@ -4,6 +4,7 @@ import ( "encoding/base64" "errors" "io/ioutil" + "log" "net/http" "strings" "time" @@ -193,6 +194,30 @@ func (c *Client) AuthRefresh(expiredAuth *Auth) (*Auth, error) { return auth, nil } +func unlockKey(e *openpgp.Entity, passphraseBytes []byte) error { + var privateKeys []*packet.PrivateKey + + // e.PrivateKey is a signing key + if e.PrivateKey != nil { + privateKeys = append(privateKeys, e.PrivateKey) + } + + // e.Subkeys are encryption keys + for _, subkey := range e.Subkeys { + if subkey.PrivateKey != nil { + privateKeys = append(privateKeys, subkey.PrivateKey) + } + } + + for _, priv := range privateKeys { + if err := priv.Decrypt(passphraseBytes); err != nil { + return err + } + } + + return nil +} + func (c *Client) Unlock(auth *Auth, passphrase string) (openpgp.EntityList, error) { passphraseBytes := []byte(passphrase) if auth.keySalt != "" { @@ -218,24 +243,8 @@ func (c *Client) Unlock(auth *Auth, passphrase string) (openpgp.EntityList, erro } for _, e := range keyRing { - var privateKeys []*packet.PrivateKey - - // e.PrivateKey is a signing key - if e.PrivateKey != nil { - privateKeys = append(privateKeys, e.PrivateKey) - } - - // e.Subkeys are encryption keys - for _, subkey := range e.Subkeys { - if subkey.PrivateKey != nil { - privateKeys = append(privateKeys, subkey.PrivateKey) - } - } - - for _, priv := range privateKeys { - if err := priv.Decrypt(passphraseBytes); err != nil { - return nil, err - } + if err := unlockKey(e, passphraseBytes); err != nil { + return nil, err } } @@ -261,50 +270,37 @@ func (c *Client) Unlock(auth *Auth, passphrase string) (openpgp.EntityList, erro c.accessToken = string(accessTokenBytes) c.keyRing = keyRing - // Get additional private keys - u, err := c.GetCurrentUser() + // Unlock additional private keys + addrs, err := c.ListAddresses() if err != nil { return nil, err } - var keyRing2 openpgp.EntityList - for _, e := range u.Addresses { - for _, f := range e.Keys { - pkey, err := f.Entity() + + for _, addr := range addrs { + for _, key := range addr.Keys { + entity, err := key.Entity() if err != nil { return nil, err } - if pkey.PrimaryKey.KeyId == keyRing[0].PrimaryKey.KeyId { - continue - } - prKey, err := openpgp.ReadArmoredKeyRing(strings.NewReader(f.PrivateKey)) - // TODO Are these errors fatal? - if err != nil { - continue - } - if len(prKey) == 0 { - continue - } - keyRing2 = append(keyRing2, prKey[0]) - } - } - for _, e := range keyRing2 { - var privateKeys []*packet.PrivateKey - if e.PrivateKey != nil { - privateKeys = append(privateKeys, e.PrivateKey) - } - for _, subkey := range e.Subkeys { - if subkey.PrivateKey != nil { - privateKeys = append(privateKeys, subkey.PrivateKey) + found := false + for _, e := range keyRing { + if e.PrimaryKey.KeyIdString() == entity.PrimaryKey.KeyIdString() { + found = true + break + } + } + if found { + continue } - } - for _, priv := range privateKeys { - if err := priv.Decrypt(passphraseBytes); err != nil { - return nil, err + if err := unlockKey(entity, passphraseBytes); err != nil { + log.Printf("failed to unlock key %v: %v", entity.PrimaryKey.KeyIdString(), err) + continue } + + keyRing = append(keyRing, entity) } - keyRing = append(keyRing, e) } return keyRing, nil diff --git a/protonmail/keys.go b/protonmail/keys.go index ee5ceb7..3e72257 100644 --- a/protonmail/keys.go +++ b/protonmail/keys.go @@ -9,13 +9,21 @@ import ( "golang.org/x/crypto/openpgp" ) +type PrivateKeyFlags int + +const ( + PrivateKeyVerify PrivateKeyFlags = 1 + PrivateKeyEncrypt = 2 +) + type PrivateKey struct { ID string Version int - PublicKey string + Flags PrivateKeyFlags PrivateKey string Fingerprint string Activation interface{} // TODO + Primary int } func (priv *PrivateKey) Entity() (*openpgp.Entity, error) {