diff --git a/cmd/hydroxide/hydroxide.go b/cmd/hydroxide/hydroxide.go index 1e21b2d..afbe5d9 100644 --- a/cmd/hydroxide/hydroxide.go +++ b/cmd/hydroxide/hydroxide.go @@ -31,8 +31,10 @@ func main() { scanner.Scan() code := scanner.Text() - err := c.Auth(username, password, code, nil) + auth, err := c.Auth(username, password, code, nil) if err != nil { log.Fatal(err) } + + log.Println(auth) } diff --git a/protonmail/auth.go b/protonmail/auth.go index da806b2..8692d19 100644 --- a/protonmail/auth.go +++ b/protonmail/auth.go @@ -79,8 +79,7 @@ const ( PasswordTwo = 2 ) -type authResp struct { - resp +type Auth struct { AccessToken string ExpiresIn int TokenType string @@ -88,17 +87,32 @@ type authResp struct { UID string `json:"Uid"` RefreshToken string EventID string - ServerProof string PasswordMode PasswordMode + + privateKey string + keySalt string +} + +type authResp struct { + resp + Auth + ServerProof string PrivateKey string KeySalt string } -func (c *Client) Auth(username, password, twoFactorCode string, info *AuthInfo) error { +func (resp *authResp) auth() *Auth { + auth := &resp.Auth + auth.privateKey = resp.PrivateKey + auth.keySalt = resp.KeySalt + return auth +} + +func (c *Client) Auth(username, password, twoFactorCode string, info *AuthInfo) (*Auth, error) { if info == nil { var err error if info, err = c.AuthInfo(username); err != nil { - return err + return nil, err } } @@ -106,7 +120,7 @@ func (c *Client) Auth(username, password, twoFactorCode string, info *AuthInfo) proofs, err := srp([]byte(password), info) if err != nil { - return err + return nil, err } reqData := &authReq{ @@ -118,23 +132,22 @@ func (c *Client) Auth(username, password, twoFactorCode string, info *AuthInfo) ClientProof: base64.StdEncoding.EncodeToString(proofs.clientProof), TwoFactorCode: twoFactorCode, } - log.Printf("%#v\n", reqData) req, err := c.newJSONRequest(http.MethodPost, "/auth", reqData) if err != nil { - return err + return nil, err } var respData authResp if err := c.doJSON(req, &respData); err != nil { - return err + return nil, err } - log.Printf("%#v\n", respData) + log.Printf("%+v\n", respData) if err := proofs.VerifyServerProof(respData.ServerProof); err != nil { - return err + return nil, err } - return nil + return respData.auth(), nil } diff --git a/protonmail/password.go b/protonmail/password.go new file mode 100644 index 0000000..c3d50bb --- /dev/null +++ b/protonmail/password.go @@ -0,0 +1,63 @@ +package protonmail + +import ( + "bytes" + "crypto/sha512" + "errors" + + "github.com/emersion/go-bcrypt" +) + +const bcryptCost = 10 + +func hashBcrypt(password, salt []byte) ([]byte, error) { + hashed, err := bcrypt.GenerateFromPasswordAndSalt(password, bcryptCost, salt) + if err != nil { + return nil, err + } + hashed = bytes.Replace(hashed, []byte("$2a$"), []byte("$2y$"), 1) + return hashed, nil +} + +func expandHash(b []byte) []byte { + var expanded []byte + var part [64]byte + + part = sha512.Sum512(append(b, 0)) + expanded = append(expanded, part[:]...) + + part = sha512.Sum512(append(b, 1)) + expanded = append(expanded, part[:]...) + + part = sha512.Sum512(append(b, 2)) + expanded = append(expanded, part[:]...) + + part = sha512.Sum512(append(b, 3)) + expanded = append(expanded, part[:]...) + + return expanded +} + +func hashPassword(version int, password, salt, modulus []byte) ([]byte, error) { + switch version { + case 3, 4: + salt = append(salt, []byte("proton")...) + hashed, err := hashBcrypt(password, salt) + if err != nil { + return nil, err + } + return expandHash(append([]byte(hashed), modulus...)), nil + default: + return nil, errors.New("unsupported auth version") + } +} + +func computeKeyPassword(password, salt []byte) ([]byte, error) { + hashed, err := hashBcrypt(password, salt) + if err != nil { + return nil, err + } + + // Remove bcrypt prefix and salt (first 29 characters) + return hashed[29:], nil +} diff --git a/protonmail/protonmail.go b/protonmail/protonmail.go index cdf9854..e355075 100644 --- a/protonmail/protonmail.go +++ b/protonmail/protonmail.go @@ -63,7 +63,12 @@ func (c *Client) newJSONRequest(method, path string, body interface{}) (*http.Re if err := json.NewEncoder(&b).Encode(body); err != nil { return nil, err } - return c.newRequest(method, path, &b) + req, err := c.newRequest(method, path, &b) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + return req, nil } func (c *Client) do(req *http.Request) (*http.Response, error) { @@ -75,6 +80,8 @@ func (c *Client) do(req *http.Request) (*http.Response, error) { } func (c *Client) doJSON(req *http.Request, respData interface{}) error { + req.Header.Set("Accept", "application/json") + resp, err := c.do(req) if err != nil { return err diff --git a/protonmail/srp.go b/protonmail/srp.go index eddd775..6bd9d29 100644 --- a/protonmail/srp.go +++ b/protonmail/srp.go @@ -3,14 +3,12 @@ package protonmail import ( "bytes" "crypto/rand" - "crypto/sha512" "crypto/subtle" "encoding/base64" "errors" "io" "math/big" - "github.com/emersion/go-bcrypt" "golang.org/x/crypto/openpgp" "golang.org/x/crypto/openpgp/clearsign" openpgperrors "golang.org/x/crypto/openpgp/errors" @@ -33,41 +31,6 @@ func decodeModulus(msg string) ([]byte, error) { return base64.StdEncoding.DecodeString(string(block.Plaintext)) } -func expandHash(b []byte) []byte { - var expanded []byte - var part [64]byte - - part = sha512.Sum512(append(b, 0)) - expanded = append(expanded, part[:]...) - - part = sha512.Sum512(append(b, 1)) - expanded = append(expanded, part[:]...) - - part = sha512.Sum512(append(b, 2)) - expanded = append(expanded, part[:]...) - - part = sha512.Sum512(append(b, 3)) - expanded = append(expanded, part[:]...) - - return expanded -} - -func hashPassword(version int, password, salt, modulus []byte) ([]byte, error) { - switch version { - case 3, 4: - salt = append(salt, []byte("proton")...) - hashed, err := bcrypt.GenerateFromPasswordAndSalt(password, 10, salt) - if err != nil { - return nil, err - } - - hashed = bytes.Replace(hashed, []byte("$2a$"), []byte("$2y$"), 1) - return expandHash(append([]byte(hashed), modulus...)), nil - default: - return nil, errors.New("unsupported auth version") - } -} - func reverse(b []byte) { for i := 0; i < len(b)/2; i++ { j := len(b) - 1 - i