diff --git a/carddav/carddav.go b/carddav/carddav.go index dd8c3a0..fe2913b 100644 --- a/carddav/carddav.go +++ b/carddav/carddav.go @@ -335,7 +335,7 @@ func (ab *addressBook) CreateAddressObject(card vcard.Card) (carddav.AddressObje func (ab *addressBook) receiveEvents(events <-chan *protonmail.Event) { for event := range events { ab.locker.Lock() - if event.Refresh == 1 { + if event.Refresh&protonmail.EventRefreshContacts != 0 { ab.cache = make(map[string]*addressObject) ab.total = -1 } else if len(event.Contacts) > 0 { diff --git a/cmd/hydroxide/hydroxide.go b/cmd/hydroxide/hydroxide.go index 931a678..7274358 100644 --- a/cmd/hydroxide/hydroxide.go +++ b/cmd/hydroxide/hydroxide.go @@ -10,12 +10,17 @@ import ( "os" "time" + imapmove "github.com/emersion/go-imap-move" + imapserver "github.com/emersion/go-imap/server" + imapspacialuse "github.com/emersion/go-imap-specialuse" "github.com/emersion/go-smtp" "github.com/howeyc/gopass" "github.com/emersion/hydroxide/auth" "github.com/emersion/hydroxide/carddav" + "github.com/emersion/hydroxide/events" "github.com/emersion/hydroxide/protonmail" + imapbackend "github.com/emersion/hydroxide/imap" smtpbackend "github.com/emersion/hydroxide/smtp" ) @@ -28,10 +33,11 @@ func newClient() *protonmail.Client { } } -func receiveEvents(c *protonmail.Client, last string, ch chan<- *protonmail.Event) { +func receiveEvents(c *protonmail.Client, ch chan<- *protonmail.Event) { t := time.NewTicker(time.Minute) defer t.Stop() + var last string for range t.C { event, err := c.GetEvent(last) if err != nil { @@ -135,7 +141,7 @@ func main() { case "smtp": port := os.Getenv("PORT") if port == "" { - port = "1465" + port = "1025" } sessions := auth.NewManager(newClient) @@ -149,6 +155,25 @@ func main() { log.Println("Starting SMTP server at", s.Addr) log.Fatal(s.ListenAndServe()) + case "imap": + port := os.Getenv("PORT") + if port == "" { + port = "1143" + } + + sessions := auth.NewManager(newClient) + eventsManager := events.NewManager() + + be := imapbackend.New(sessions, eventsManager) + s := imapserver.New(be) + s.Addr = "127.0.0.1:" + port + s.AllowInsecureAuth = true // TODO: remove this + //s.Debug = os.Stdout + s.Enable(imapspacialuse.NewExtension()) + s.Enable(imapmove.NewExtension()) + + log.Println("Starting IMAP server at", s.Addr) + log.Fatal(s.ListenAndServe()) case "carddav": port := os.Getenv("PORT") if port == "" { @@ -156,6 +181,7 @@ func main() { } sessions := auth.NewManager(newClient) + eventsManager := events.NewManager() handlers := make(map[string]http.Handler) s := &http.Server{ @@ -183,9 +209,9 @@ func main() { h, ok := handlers[username] if !ok { - events := make(chan *protonmail.Event) - go receiveEvents(c, "", events) - h = carddav.NewHandler(c, privateKeys, events) + ch := make(chan *protonmail.Event) + eventsManager.Register(c, username, ch, nil) + h = carddav.NewHandler(c, privateKeys, ch) handlers[username] = h } diff --git a/events/events.go b/events/events.go new file mode 100644 index 0000000..9a1505e --- /dev/null +++ b/events/events.go @@ -0,0 +1,112 @@ +package events + +import ( + "log" + "sync" + "time" + + "github.com/emersion/hydroxide/protonmail" +) + +const pollInterval = 30 * time.Second + +type Receiver struct { + c *protonmail.Client + + locker sync.Mutex + channels []chan<- *protonmail.Event + + poll chan struct{} +} + +func (r *Receiver) receiveEvents() { + t := time.NewTicker(pollInterval) + defer t.Stop() + + var last string + for { + event, err := r.c.GetEvent(last) + if err != nil { + log.Println("cannot receive event:", err) + continue + } + last = event.ID + + r.locker.Lock() + n := len(r.channels) + for _, ch := range r.channels { + ch <- event + } + r.locker.Unlock() + + if n == 0 { + break + } + + select { + case <-t.C: + case <-r.poll: + } + } +} + +func (r *Receiver) Poll() { + r.poll <- struct{}{} +} + +type Manager struct { + receivers map[string]*Receiver + locker sync.Mutex +} + +func NewManager() *Manager { + return &Manager{ + receivers: make(map[string]*Receiver), + } +} + +func (m *Manager) Register(c *protonmail.Client, username string, ch chan<- *protonmail.Event, done <-chan struct{}) *Receiver { + m.locker.Lock() + defer m.locker.Unlock() + + r, ok := m.receivers[username] + if ok { + r.locker.Lock() + r.channels = append(r.channels, ch) + r.locker.Unlock() + } else { + r = &Receiver{ + c: c, + channels: []chan<- *protonmail.Event{ch}, + poll: make(chan struct{}), + } + + go func() { + r.receiveEvents() + + m.locker.Lock() + delete(m.receivers, username) + m.locker.Unlock() + }() + + m.receivers[username] = r + } + + if done != nil { + go func() { + <-done + + r.locker.Lock() + for i, c := range r.channels { + if c == ch { + r.channels = append(r.channels[:i], r.channels[i+1:]...) + } + } + r.locker.Unlock() + + close(ch) + }() + } + + return r +} diff --git a/imap/backend.go b/imap/backend.go new file mode 100644 index 0000000..69e8cb2 --- /dev/null +++ b/imap/backend.go @@ -0,0 +1,42 @@ +package imap + +import ( + "errors" + + imapbackend "github.com/emersion/go-imap/backend" + + "github.com/emersion/hydroxide/auth" + "github.com/emersion/hydroxide/events" +) + +var errNotYetImplemented = errors.New("not yet implemented") + +type backend struct { + sessions *auth.Manager + eventsManager *events.Manager + updates chan interface{} +} + +func (be *backend) Login(username, password string) (imapbackend.User, error) { + c, privateKeys, err := be.sessions.Auth(username, password) + if err != nil { + return nil, err + } + + u, err := c.GetCurrentUser() + if err != nil { + return nil, err + } + + // TODO: decrypt private keys in u.Addresses + + return newUser(be, c, u, privateKeys) +} + +func (be *backend) Updates() <-chan interface{} { + return be.updates +} + +func New(sessions *auth.Manager, eventsManager *events.Manager) imapbackend.Backend { + return &backend{sessions, eventsManager, make(chan interface{}, 50)} +} diff --git a/imap/database/mailbox.go b/imap/database/mailbox.go new file mode 100644 index 0000000..ccd3e4c --- /dev/null +++ b/imap/database/mailbox.go @@ -0,0 +1,197 @@ +package database + +import ( + "bytes" + "encoding/binary" + "errors" + + "github.com/boltdb/bolt" + + "github.com/emersion/hydroxide/protonmail" +) + +func serializeUID(uid uint32) []byte { + b := make([]byte, 4) + binary.BigEndian.PutUint32(b, uid) + return b +} + +func unserializeUID(b []byte) uint32 { + return binary.BigEndian.Uint32(b) +} + +func mailboxCreateMessage(b *bolt.Bucket, apiID string) (seqNum uint32, err error) { + want := []byte(apiID) + c := b.Cursor() + var n uint32 = 1 + for k, v := c.First(); k != nil; k, v = c.Next() { + if bytes.Equal(v, want) { + return n, nil + } + n++ + } + + id, _ := b.NextSequence() + uid := uint32(id) + return n, b.Put(serializeUID(uid), want) +} + +func mailboxDeleteMessage(b *bolt.Bucket, apiID string) (seqNum uint32, err error) { + want := []byte(apiID) + c := b.Cursor() + var n uint32 = 1 + for k, v := c.First(); k != nil; k, v = c.Next() { + if bytes.Equal(v, want) { + return n, b.Delete(k) + } + n++ + } + return 0, nil +} + +type Mailbox struct { + labelID string + u *User +} + +func (mbox *Mailbox) bucket(tx *bolt.Tx) (*bolt.Bucket, error) { + b := tx.Bucket(mailboxesBucket) + if b == nil { + return nil, errors.New("cannot find mailboxes bucket") + } + b = b.Bucket([]byte(mbox.labelID)) + if b == nil { + return nil, errors.New("cannot find mailbox bucket") + } + return b, nil +} + +func (mbox *Mailbox) Sync(messages []*protonmail.Message) error { + return mbox.u.db.Update(func(tx *bolt.Tx) error { + b, err := mbox.bucket(tx) + if err != nil { + return err + } + + for _, msg := range messages { + if _, err := mailboxCreateMessage(b, msg.ID); err != nil { + return err + } + } + + return userSync(tx, messages) + }) +} + +func (mbox *Mailbox) UidNext() (uint32, error) { + var uid uint32 + err := mbox.u.db.View(func(tx *bolt.Tx) error { + b, err := mbox.bucket(tx) + if err != nil { + return err + } + + uid = uint32(b.Sequence() + 1) + return nil + }) + return uid, err +} + +func (mbox *Mailbox) FromUid(uid uint32) (apiID string, err error) { + err = mbox.u.db.View(func(tx *bolt.Tx) error { + b, err := mbox.bucket(tx) + if err != nil { + return err + } + + k := serializeUID(uid) + v := b.Get(k) + if v == nil { + return ErrNotFound + } + apiID = string(v) + return nil + }) + return +} + +func (mbox *Mailbox) FromSeqNum(seqNum uint32) (apiID string, err error) { + err = mbox.u.db.View(func(tx *bolt.Tx) error { + b, err := mbox.bucket(tx) + if err != nil { + return err + } + + c := b.Cursor() + var n uint32 = 1 + for k, v := c.First(); k != nil; k, v = c.Next() { + if seqNum == n { + apiID = string(v) + return nil + } + n++ + } + + return ErrNotFound + }) + return +} + +func (mbox *Mailbox) FromApiID(apiID string) (seqNum uint32, uid uint32, err error) { + err = mbox.u.db.View(func(tx *bolt.Tx) error { + b, err := mbox.bucket(tx) + if err != nil { + return err + } + + want := []byte(apiID) + c := b.Cursor() + var n uint32 = 1 + for k, v := c.First(); k != nil; k, v = c.Next() { + if bytes.Equal(v, want) { + seqNum = n + uid = unserializeUID(k) + return nil + } + n++ + } + + return ErrNotFound + }) + return +} + +func (mbox *Mailbox) ForEach(f func(seqNum, uid uint32, apiID string) error) error { + return mbox.u.db.View(func(tx *bolt.Tx) error { + b, err := mbox.bucket(tx) + if err != nil { + return err + } + + c := b.Cursor() + var n uint32 = 1 + for k, v := c.First(); k != nil; k, v = c.Next() { + if err := f(n, unserializeUID(k), string(v)); err != nil { + return err + } + n++ + } + + return nil + }) +} + +func (mbox *Mailbox) Reset() error { + return mbox.u.db.Update(func(tx *bolt.Tx) error { + b := tx.Bucket(mailboxesBucket) + if b == nil { + return errors.New("cannot find mailboxes bucket") + } + k := []byte(mbox.labelID) + if err := b.DeleteBucket(k); err != nil { + return err + } + _, err := b.CreateBucket(k) + return err + }) +} diff --git a/imap/database/user.go b/imap/database/user.go new file mode 100644 index 0000000..e93d151 --- /dev/null +++ b/imap/database/user.go @@ -0,0 +1,233 @@ +package database + +import ( + "encoding/json" + "errors" + + "github.com/boltdb/bolt" + + "github.com/emersion/hydroxide/protonmail" +) + +var ErrNotFound = errors.New("message not found in local database") + +var ( + mailboxesBucket = []byte("mailboxes") + messagesBucket = []byte("messages") +) + +func userMessage(b *bolt.Bucket, apiID string) (*protonmail.Message, error) { + k := []byte(apiID) + v := b.Get(k) + if v == nil { + return nil, ErrNotFound + } + + msg := &protonmail.Message{} + err := json.Unmarshal(v, msg) + return msg, err +} + +func userCreateMessage(b *bolt.Bucket, msg *protonmail.Message) error { + k := []byte(msg.ID) + v, err := json.Marshal(msg) + if err != nil { + return err + } + return b.Put(k, v) +} + +func userSync(tx *bolt.Tx, messages []*protonmail.Message) error { + b, err := tx.CreateBucketIfNotExists(messagesBucket) + if err != nil { + return err + } + + for _, msg := range messages { + if err := userCreateMessage(b, msg); err != nil { + return err + } + } + + return nil +} + +type User struct { + db *bolt.DB +} + +func (u *User) Mailbox(labelID string) (*Mailbox, error) { + err := u.db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucketIfNotExists(mailboxesBucket) + if err != nil { + return err + } + _, err = b.CreateBucketIfNotExists([]byte(labelID)) + return err + }) + if err != nil { + return nil, err + } + + return &Mailbox{labelID, u}, nil +} + +func (u *User) Message(apiID string) (*protonmail.Message, error) { + var msg *protonmail.Message + err := u.db.View(func (tx *bolt.Tx) error { + b := tx.Bucket(messagesBucket) + if b == nil { + return ErrNotFound + } + + var err error + msg, err = userMessage(b, apiID) + return err + }) + return msg, err +} + +func (u *User) ResetMessages() error { + return u.db.Update(func(tx *bolt.Tx) error { + return tx.DeleteBucket(messagesBucket) + }) +} + +func (u *User) CreateMessage(msg *protonmail.Message) (seqNums map[string]uint32, err error) { + seqNums = make(map[string]uint32) + err = u.db.Update(func(tx *bolt.Tx) error { + messages, err := tx.CreateBucketIfNotExists(messagesBucket) + if err != nil { + return err + } + + if err := userCreateMessage(messages, msg); err != nil { + return err + } + + mailboxes, err := tx.CreateBucketIfNotExists(mailboxesBucket) + if err != nil { + return err + } + for _, labelID := range msg.LabelIDs { + mbox, err := mailboxes.CreateBucketIfNotExists([]byte(labelID)) + if err != nil { + return err + } + + seqNum, err := mailboxCreateMessage(mbox, msg.ID) + if err != nil { + return err + } + seqNums[labelID] = seqNum + } + + return nil + }) + return +} + +func (u *User) UpdateMessage(apiID string, update *protonmail.EventMessageUpdate) (createdSeqNums map[string]uint32, deletedSeqNums map[string]uint32, err error) { + createdSeqNums = make(map[string]uint32) + deletedSeqNums = make(map[string]uint32) + err = u.db.Update(func(tx *bolt.Tx) error { + messages := tx.Bucket(messagesBucket) + if messages == nil { + return errors.New("cannot update message in local DB: messages bucket doesn't exist") + } + + msg, err := userMessage(messages, apiID) + if err != nil { + return err + } + + addedLabels, removedLabels := update.DiffLabelIDs(msg.LabelIDs) + + mailboxes, err := tx.CreateBucketIfNotExists(mailboxesBucket) + if err != nil { + return err + } + for _, labelID := range addedLabels { + mbox, err := mailboxes.CreateBucketIfNotExists([]byte(labelID)) + if err != nil { + return err + } + + seqNum, err := mailboxCreateMessage(mbox, apiID) + if err != nil { + return err + } + createdSeqNums[labelID] = seqNum + } + for _, labelID := range removedLabels { + mbox := mailboxes.Bucket([]byte(labelID)) + if mbox == nil { + continue + } + + seqNum, err := mailboxDeleteMessage(mbox, apiID) + if err != nil { + return err + } + deletedSeqNums[labelID] = seqNum + } + + update.Patch(msg) + return userCreateMessage(messages, msg) + }) + return +} + +func (u *User) DeleteMessage(apiID string) (seqNums map[string]uint32, err error) { + seqNums = make(map[string]uint32) + err = u.db.Update(func(tx *bolt.Tx) error { + messages:= tx.Bucket(messagesBucket) + if messages == nil { + return nil + } + + msg, err := userMessage(messages, apiID) + if err == ErrNotFound { + return nil + } else if err != nil { + return err + } + + if err := messages.Delete([]byte(apiID)); err != nil { + return err + } + + mailboxes := tx.Bucket(mailboxesBucket) + if mailboxes == nil { + return nil + } + for _, labelID := range msg.LabelIDs { + mbox := mailboxes.Bucket([]byte(labelID)) + if mbox == nil { + continue + } + + seqNum, err := mailboxDeleteMessage(mbox, msg.ID) + if err != nil { + return err + } + seqNums[labelID] = seqNum + } + + return nil + }) + return +} + +func (u *User) Close() error { + return u.db.Close() +} + +func Open(path string) (*User, error) { + db, err := bolt.Open(path, 0700, nil) + if err != nil { + return nil, err + } + + return &User{db}, nil +} diff --git a/imap/mailbox.go b/imap/mailbox.go new file mode 100644 index 0000000..6a4c3df --- /dev/null +++ b/imap/mailbox.go @@ -0,0 +1,510 @@ +package imap + +import ( + "errors" + "log" + "strings" + "sync" + "time" + + "github.com/emersion/go-imap" + imapbackend "github.com/emersion/go-imap/backend" + + "github.com/emersion/hydroxide/imap/database" + "github.com/emersion/hydroxide/protonmail" +) + +const delimiter = "/" + +type mailbox struct { + name string + label string + flags []string + + u *user + db *database.Mailbox + + initialized bool + initializedLock sync.Mutex + + total, unread int + deleted map[string]struct{} +} + +func (mbox *mailbox) Name() string { + return mbox.name +} + +func (mbox *mailbox) Info() (*imap.MailboxInfo, error) { + return &imap.MailboxInfo{ + Attributes: append(mbox.flags, imap.NoInferiorsAttr), + Delimiter: delimiter, + Name: mbox.name, + }, nil +} + +func (mbox *mailbox) Status(items []imap.StatusItem) (*imap.MailboxStatus, error) { + status := imap.NewMailboxStatus(mbox.name, items) + status.Flags = mbox.flags + status.PermanentFlags = []string{imap.SeenFlag, imap.FlaggedFlag, imap.DeletedFlag} + status.UnseenSeqNum = 0 // TODO + + for _, name := range items { + switch name { + case imap.StatusMessages: + status.Messages = uint32(mbox.total) + case imap.StatusUidNext: + uidNext, err := mbox.db.UidNext() + if err != nil { + return nil, err + } + status.UidNext = uidNext + case imap.StatusUidValidity: + status.UidValidity = 1 + case imap.StatusRecent: + status.Recent = 0 + case imap.StatusUnseen: + status.Unseen = uint32(mbox.unread) + } + } + + return status, nil +} + +func (mbox *mailbox) SetSubscribed(subscribed bool) error { + return errNotYetImplemented // TODO +} + +func (mbox *mailbox) Check() error { + return nil +} + +func (mbox *mailbox) sync() error { + log.Printf("Synchronizing mailbox %v...", mbox.name) + + // TODO: don't do this without incrementing UIDVALIDITY + if err := mbox.db.Reset(); err != nil { + return err + } + + filter := &protonmail.MessageFilter{ + PageSize: 150, + Label: mbox.label, + Sort: "ID", + Asc: true, + } + + total := -1 + for { + offset := filter.PageSize * filter.Page + if total >= 0 && offset > total { + break + } + + var page []*protonmail.Message + var err error + total, page, err = mbox.u.c.ListMessages(filter) + if err != nil { + return err + } + + if err := mbox.db.Sync(page); err != nil { + return err + } + + filter.Page++ + } + + log.Printf("Synchronizing mailbox %v: done.", mbox.name) + + return nil +} + +func (mbox *mailbox) init() error { + mbox.initializedLock.Lock() + defer mbox.initializedLock.Unlock() + + if mbox.initialized { + return nil + } + + // TODO: sync only the first time + if err := mbox.sync(); err != nil { + return err + } + + mbox.initialized = true + return nil +} + +func (mbox *mailbox) reset() error { + mbox.initializedLock.Lock() + defer mbox.initializedLock.Unlock() + + mbox.initialized = false + + return mbox.db.Reset() +} + +func (mbox *mailbox) fetchFlags(msg *protonmail.Message) []string { + flags := fetchFlags(msg) + if _, ok := mbox.deleted[msg.ID]; ok { + flags = append(flags, imap.DeletedFlag) + } + return flags +} + +func (mbox *mailbox) fetchMessage(isUid bool, id uint32, items []imap.FetchItem) (*imap.Message, error) { + var apiID string + var err error + if isUid { + apiID, err = mbox.db.FromUid(id) + } else { + apiID, err = mbox.db.FromSeqNum(id) + } + if err != nil { + return nil, err + } + + seqNum, uid, err := mbox.db.FromApiID(apiID) + if err != nil { + return nil, err + } + + msg, err := mbox.u.db.Message(apiID) + if err != nil { + return nil, err + } + + fetched := imap.NewMessage(seqNum, items) + for _, item := range items { + switch item { + case imap.FetchEnvelope: + fetched.Envelope = fetchEnvelope(msg) + case imap.FetchBody, imap.FetchBodyStructure: + bs, err := mbox.fetchBodyStructure(msg, item == imap.FetchBodyStructure) + if err != nil { + return nil, err + } + fetched.BodyStructure = bs + case imap.FetchFlags: + fetched.Flags = mbox.fetchFlags(msg) + case imap.FetchInternalDate: + fetched.InternalDate = time.Unix(msg.Time, 0) + case imap.FetchRFC822Size: + fetched.Size = uint32(msg.Size) + case imap.FetchUid: + fetched.Uid = uid + default: + section, err := imap.ParseBodySectionName(item) + if err != nil { + break + } + + l, err := mbox.fetchBodySection(msg, section) + if err != nil { + return nil, err + } + fetched.Body[section] = l + } + } + + return fetched, nil +} + +func (mbox *mailbox) ListMessages(uid bool, seqSet *imap.SeqSet, items []imap.FetchItem, ch chan<- *imap.Message) error { + defer close(ch) + + if err := mbox.init(); err != nil { + return err + } + + for _, seq := range seqSet.Set { + start := seq.Start + if start == 0 { + start = 1 + } + + stop := seq.Stop + if stop == 0 { + if uid { + uidNext, err := mbox.db.UidNext() + if err != nil { + return err + } + stop = uidNext - 1 + } else { + stop = uint32(mbox.total) + } + } + + for i := start; i <= stop; i++ { + msg, err := mbox.fetchMessage(uid, i, items) + if err == database.ErrNotFound { + continue + } else if err != nil { + return err + } + if msg != nil { + ch <- msg + } + } + } + + return nil +} + +func matchString(s, substr string) bool { + return strings.Contains(strings.ToLower(s), strings.ToLower(substr)) +} + +func (mbox *mailbox) SearchMessages(isUID bool, c *imap.SearchCriteria) ([]uint32, error) { + if err := mbox.init(); err != nil { + return nil, err + } + + // TODO: c.Not, c.Or + if c.Not != nil || c.Or != nil { + return nil, errors.New("search queries with NOT or OR clauses or not yet implemented") + } + + var results []uint32 + err := mbox.db.ForEach(func(seqNum, uid uint32, apiID string) error { + if c.SeqNum != nil && !c.SeqNum.Contains(seqNum) { + return nil + } + if c.Uid != nil && !c.Uid.Contains(uid) { + return nil + } + + // TODO: fetch message from local DB only if needed + msg, err := mbox.u.db.Message(apiID) + if err != nil { + return err + } + + flags := make(map[string]bool) + for _, flag := range mbox.fetchFlags(msg) { + flags[flag] = true + } + for _, f := range c.WithFlags { + if !flags[f] { + return nil + } + } + for _, f := range c.WithoutFlags { + if flags[f] { + return nil + } + } + + date := time.Unix(msg.Time, 0).Round(24 * time.Hour) + if !c.Since.IsZero() && !date.After(c.Since) { + return nil + } + if !c.Before.IsZero() && !date.Before(c.Before) { + return nil + } + // TODO: this date should be from the Date MIME header + if !c.SentBefore.IsZero() && !date.Before(c.SentBefore) { + return nil + } + if !c.SentSince.IsZero() && !date.After(c.SentSince) { + return nil + } + + h := messageHeader(msg) + for key, wantValues := range c.Header { + values, ok := h[key] + for _, wantValue := range wantValues { + if wantValue == "" && !ok { + return nil + } + if wantValue != "" { + ok := false + for _, v := range values { + if matchString(v, wantValue) { + ok = true + break + } + } + if !ok { + return nil + } + } + } + } + + // TODO: c.Body, c.Text + + if c.Larger > 0 && uint32(msg.Size) < c.Larger { + return nil + } + if c.Smaller > 0 && uint32(msg.Size) > c.Smaller { + return nil + } + + if isUID { + results = append(results, uid) + } else { + results = append(results, seqNum) + } + return nil + }) + if err != nil { + return nil, err + } + return results, nil +} + +func (mbox *mailbox) CreateMessage(flags []string, date time.Time, body imap.Literal) error { + if mbox.label != protonmail.LabelDraft { + return errors.New("cannot create messages outside the Drafts mailbox") + } + + if err := mbox.init(); err != nil { + return err + } + + _, err := createMessage(mbox.u.c, mbox.u.u, mbox.u.privateKeys, body) + if err != nil { + return err + } + + return mbox.Poll() +} + +func (mbox *mailbox) fromSeqSet(isUID bool, seqSet *imap.SeqSet) ([]string, error) { + var apiIDs []string + err := mbox.db.ForEach(func(seqNum, uid uint32, apiID string) error { + var id uint32 + if isUID { + id = uid + } else { + id = seqNum + } + + if seqSet.Contains(id) { + apiIDs = append(apiIDs, apiID) + } + return nil + }) + return apiIDs, err +} + +func (mbox *mailbox) UpdateMessagesFlags(uid bool, seqSet *imap.SeqSet, op imap.FlagsOp, flags []string) error { + if err := mbox.init(); err != nil { + return err + } + + apiIDs, err := mbox.fromSeqSet(uid, seqSet) + if err != nil { + return err + } + + // TODO: imap.SetFlags should remove currently set flags + + for _, flag := range flags { + var err error + switch flag { + case imap.SeenFlag: + switch op { + case imap.SetFlags, imap.AddFlags: + err = mbox.u.c.MarkMessagesRead(apiIDs) + case imap.RemoveFlags: + err = mbox.u.c.MarkMessagesUnread(apiIDs) + } + case imap.FlaggedFlag: + switch op { + case imap.SetFlags, imap.AddFlags: + err = mbox.u.c.LabelMessages(protonmail.LabelStarred, apiIDs) + case imap.RemoveFlags: + err = mbox.u.c.UnlabelMessages(protonmail.LabelStarred, apiIDs) + } + case imap.DeletedFlag: + // TODO: send updates + switch op { + case imap.SetFlags, imap.AddFlags: + for _, apiID := range apiIDs { + mbox.deleted[apiID] = struct{}{} + } + case imap.RemoveFlags: + for _, apiID := range apiIDs { + delete(mbox.deleted, apiID) + } + } + } + if err != nil { + return err + } + } + + return mbox.Poll() +} + +func (mbox *mailbox) CopyMessages(uid bool, seqSet *imap.SeqSet, destName string) error { + if err := mbox.init(); err != nil { + return err + } + + apiIDs, err := mbox.fromSeqSet(uid, seqSet) + if err != nil { + return err + } + + dest := mbox.u.getMailbox(destName) + if dest == nil { + return imapbackend.ErrNoSuchMailbox + } + + if err := mbox.u.c.LabelMessages(dest.label, apiIDs); err != nil { + return err + } + return mbox.Poll() +} + +func (mbox *mailbox) MoveMessages(uid bool, seqSet *imap.SeqSet, destName string) error { + if err := mbox.init(); err != nil { + return err + } + + apiIDs, err := mbox.fromSeqSet(uid, seqSet) + if err != nil { + return err + } + + dest := mbox.u.getMailbox(destName) + if dest == nil { + return imapbackend.ErrNoSuchMailbox + } + + if err := mbox.u.c.LabelMessages(dest.label, apiIDs); err != nil { + return err + } + if err := mbox.u.c.UnlabelMessages(mbox.label, apiIDs); err != nil { + return err + } + return mbox.Poll() +} + +func (mbox *mailbox) Expunge() error { + if err := mbox.init(); err != nil { + return err + } + + apiIDs := make([]string, 0, len(mbox.deleted)) + for apiID := range mbox.deleted { + apiIDs = append(apiIDs, apiID) + } + + if err := mbox.u.c.DeleteMessages(apiIDs); err != nil { + return err + } + + return mbox.Poll() +} + +func (mbox *mailbox) Poll() error { + mbox.u.poll() + return nil +} diff --git a/imap/message.go b/imap/message.go new file mode 100644 index 0000000..48dab72 --- /dev/null +++ b/imap/message.go @@ -0,0 +1,552 @@ +package imap + +import ( + "bytes" + "errors" + "fmt" + "io" + "log" + "strings" + "time" + + "github.com/emersion/go-imap" + "github.com/emersion/go-message" + "github.com/emersion/go-message/mail" + "golang.org/x/crypto/openpgp" + + "github.com/emersion/hydroxide/protonmail" +) + +func messageID(msg *protonmail.Message) string { + if msg.ExternalID != "" { + return msg.ExternalID + } else { + return msg.ID + "@protonmail.com" + } +} + +func formatHeader(h mail.Header) string { + var b bytes.Buffer + for k, values := range h.Header { + for _, v := range values { + b.WriteString(fmt.Sprintf("%s: %s\r\n", k, v)) + } + } + return b.String() +} + +func protonmailAddressList(addresses []*mail.Address) []*protonmail.MessageAddress { + l := make([]*protonmail.MessageAddress, len(addresses)) + for i, addr := range addresses { + l[i] = &protonmail.MessageAddress{ + Name: addr.Name, + Address: addr.Address, + } + } + return l +} + +func imapAddress(addr *protonmail.MessageAddress) *imap.Address { + parts := strings.SplitN(addr.Address, "@", 2) + if len(parts) < 2 { + parts = append(parts, "") + } + + return &imap.Address{ + PersonalName: addr.Name, + MailboxName: parts[0], + HostName: parts[1], + } +} + +func imapAddressList(addresses []*protonmail.MessageAddress) []*imap.Address { + l := make([]*imap.Address, len(addresses)) + for i, addr := range addresses { + l[i] = imapAddress(addr) + } + return l +} + +func fetchEnvelope(msg *protonmail.Message) *imap.Envelope { + var replyTo []*imap.Address + if msg.ReplyTo != nil { + replyTo = []*imap.Address{imapAddress(msg.ReplyTo)} + } + + return &imap.Envelope{ + Date: time.Unix(msg.Time, 0), + Subject: msg.Subject, + From: []*imap.Address{imapAddress(msg.Sender)}, + // TODO: Sender + ReplyTo: replyTo, + To: imapAddressList(msg.ToList), + Cc: imapAddressList(msg.CCList), + Bcc: imapAddressList(msg.BCCList), + // TODO: InReplyTo + MessageId: messageID(msg), + } +} + +func hasLabel(msg *protonmail.Message, labelID string) bool { + for _, id := range msg.LabelIDs { + if labelID == id { + return true + } + } + return false +} + +// Doesn't support imap.DeletedFlag. +func fetchFlags(msg *protonmail.Message) []string { + var flags []string + if msg.IsRead != 0 { + flags = append(flags, imap.SeenFlag) + } + if msg.IsReplied != 0 || msg.IsRepliedAll != 0 { + flags = append(flags, imap.AnsweredFlag) + } + for _, label := range msg.LabelIDs { + switch label { + case protonmail.LabelStarred: + flags = append(flags, imap.FlaggedFlag) + case protonmail.LabelDraft: + flags = append(flags, imap.DraftFlag) + } + } + return flags +} + +func splitMIMEType(t string) (string, string) { + parts := strings.SplitN(t, "/", 2) + if len(parts) < 2 { + return "text", "plain" + } + return parts[0], parts[1] +} + +func (mbox *mailbox) fetchBodyStructure(msg *protonmail.Message, extended bool) (*imap.BodyStructure, error) { + if msg.NumAttachments > 0 { + var err error + msg, err = mbox.u.c.GetMessage(msg.ID) + if err != nil { + return nil, err + } + } + + inlineType, inlineSubType := splitMIMEType(msg.MIMEType) + parts := []*imap.BodyStructure{ + &imap.BodyStructure{ + MIMEType: inlineType, + MIMESubType: inlineSubType, + Encoding: "quoted-printable", + Size: uint32(len(msg.Body)), + Extended: extended, + Disposition: "inline", + }, + } + + for _, att := range msg.Attachments { + attType, attSubType := splitMIMEType(att.MIMEType) + parts = append(parts, &imap.BodyStructure{ + MIMEType: attType, + MIMESubType: attSubType, + Id: att.ContentID, + Encoding: "base64", + Size: uint32(att.Size), + Extended: extended, + Disposition: "attachment", + DispositionParams: map[string]string{"filename": att.Name}, + }) + } + + return &imap.BodyStructure{ + MIMEType: "multipart", + MIMESubType: "mixed", + // TODO: Params: map[string]string{"boundary": ...}, + // TODO: Size + Parts: parts, + Extended: extended, + }, nil +} + +func (mbox *mailbox) inlineBody(msg *protonmail.Message) (io.Reader, error) { + md, err := msg.Read(mbox.u.privateKeys, nil) + if err != nil { + return nil, err + } + + // TODO: check signature + return md.UnverifiedBody, nil +} + +func (mbox *mailbox) attachmentBody(att *protonmail.Attachment) (io.Reader, error) { + rc, err := mbox.u.c.GetAttachment(att.ID) + if err != nil { + return nil, err + } + + md, err := att.Read(rc, mbox.u.privateKeys, nil) + if err != nil { + return nil, err + } + + // TODO: check signature + return md.UnverifiedBody, nil +} + +func inlineHeader(msg *protonmail.Message) message.Header { + h := mail.NewTextHeader() + if msg.MIMEType != "" { + h.SetContentType(msg.MIMEType, nil) + } else { + log.Println("Sending an inline header without its proper MIME type") + } + h.Set("Content-Transfer-Encoding", "quoted-printable") + return h.Header +} + +func attachmentHeader(att *protonmail.Attachment) message.Header { + h := mail.NewAttachmentHeader() + h.SetContentType(att.MIMEType, nil) + h.Set("Content-Transfer-Encoding", "base64") + h.SetFilename(att.Name) + if att.ContentID != "" { + h.Set("Content-Id", att.ContentID) + } + return h.Header +} + +func mailAddress(addr *protonmail.MessageAddress) *mail.Address { + return &mail.Address{ + Name: addr.Name, + Address: addr.Address, + } +} + +func mailAddressList(addresses []*protonmail.MessageAddress) []*mail.Address { + l := make([]*mail.Address, len(addresses)) + for i, addr := range addresses { + l[i] = mailAddress(addr) + } + return l +} + +func messageHeader(msg *protonmail.Message) message.Header { + h := mail.NewHeader() + h.SetContentType("multipart/mixed", nil) + h.SetDate(time.Unix(msg.Time, 0)) + h.SetSubject(msg.Subject) + h.SetAddressList("From", []*mail.Address{mailAddress(msg.Sender)}) + if msg.ReplyTo != nil { + h.SetAddressList("Reply-To", []*mail.Address{mailAddress(msg.ReplyTo)}) + } + if len(msg.ToList) > 0 { + h.SetAddressList("To", mailAddressList(msg.ToList)) + } + if len(msg.CCList) > 0 { + h.SetAddressList("Cc", mailAddressList(msg.CCList)) + } + if len(msg.BCCList) > 0 { + h.SetAddressList("Bcc", mailAddressList(msg.BCCList)) + } + // TODO: In-Reply-To + h.Set("Message-Id", messageID(msg)) + return h.Header +} + +func (mbox *mailbox) fetchBodySection(msg *protonmail.Message, section *imap.BodySectionName) (imap.Literal, error) { + // TODO: section.Peek + + b := new(bytes.Buffer) + + if len(section.Path) == 0 { + w, err := message.CreateWriter(b, messageHeader(msg)) + if err != nil { + return nil, err + } + + if section.Specifier == imap.TextSpecifier { + b.Reset() + } + + switch section.Specifier { + case imap.EntireSpecifier, imap.TextSpecifier: + msg, err := mbox.u.c.GetMessage(msg.ID) + if err != nil { + return nil, err + } + + pw, err := w.CreatePart(inlineHeader(msg)) + if err != nil { + return nil, err + } + pr, err := mbox.inlineBody(msg) + if err != nil { + return nil, err + } + if _, err := io.Copy(pw, pr); err != nil { + return nil, err + } + pw.Close() + + for _, att := range msg.Attachments { + pw, err := w.CreatePart(attachmentHeader(att)) + if err != nil { + return nil, err + } + pr, err := mbox.attachmentBody(att) + if err != nil { + return nil, err + } + if _, err := io.Copy(pw, pr); err != nil { + return nil, err + } + pw.Close() + } + } + + w.Close() + } else { + if len(section.Path) > 1 { + return nil, errors.New("invalid body section path length") + } + + var h message.Header + var getBody func() (io.Reader, error) + if part := section.Path[0]; part == 1 { + // TODO: only fetch the message if the body is needed + // For now we fetch it in all cases because the MIME type is not included + // in the cached message, and inlineHeader needs it + msg, err := mbox.u.c.GetMessage(msg.ID) + if err != nil { + return nil, err + } + + h = inlineHeader(msg) + getBody = func() (io.Reader, error) { + return mbox.inlineBody(msg) + } + } else { + i := part - 2 + if i >= msg.NumAttachments { + return nil, errors.New("invalid attachment section path") + } + + msg, err := mbox.u.c.GetMessage(msg.ID) + if err != nil { + return nil, err + } + + att := msg.Attachments[i] + h = attachmentHeader(att) + getBody = func() (io.Reader, error) { + return mbox.attachmentBody(att) + } + } + + w, err := message.CreateWriter(b, h) + if err != nil { + return nil, err + } + + if section.Specifier == imap.TextSpecifier { + b.Reset() + } + + switch section.Specifier { + case imap.EntireSpecifier, imap.TextSpecifier: + r, err := getBody() + if err != nil { + return nil, err + } + + if _, err := io.Copy(w, r); err != nil { + return nil, err + } + } + + w.Close() + } + + var l imap.Literal = b + if section.Partial != nil { + l = bytes.NewReader(section.ExtractPartial(b.Bytes())) + } + + return l, nil +} + +func createMessage(c *protonmail.Client, u *protonmail.User, privateKeys openpgp.EntityList, r io.Reader) (*protonmail.Message, error) { + // Parse the incoming MIME message header + mr, err := mail.CreateReader(r) + if err != nil { + return nil, err + } + + subject, _ := mr.Header.Subject() + fromList, _ := mr.Header.AddressList("From") + toList, _ := mr.Header.AddressList("To") + ccList, _ := mr.Header.AddressList("Cc") + bccList, _ := mr.Header.AddressList("Bcc") + + if len(fromList) != 1 { + return nil, errors.New("the From field must contain exactly one address") + } + if len(toList) == 0 && len(ccList) == 0 && len(bccList) == 0 { + return nil, errors.New("no recipient specified") + } + + fromAddrStr := fromList[0].Address + var fromAddr *protonmail.Address + for _, addr := range u.Addresses { + if strings.EqualFold(addr.Email, fromAddrStr) { + fromAddr = addr + break + } + } + if fromAddr == nil { + return nil, errors.New("unknown sender address") + } + if len(fromAddr.Keys) == 0 { + return nil, errors.New("sender address has no private key") + } + + // TODO: get appropriate private key + encryptedPrivateKey, err := fromAddr.Keys[0].Entity() + if err != nil { + return nil, fmt.Errorf("cannot parse sender private key: %v", err) + } + + var privateKey *openpgp.Entity + for _, e := range privateKeys { + if e.PrimaryKey.KeyId == encryptedPrivateKey.PrimaryKey.KeyId { + privateKey = e + break + } + } + if privateKey == nil { + return nil, errors.New("sender address key hasn't been decrypted") + } + + msg := &protonmail.Message{ + ToList: protonmailAddressList(toList), + CCList: protonmailAddressList(ccList), + BCCList: protonmailAddressList(bccList), + Subject: subject, + Header: formatHeader(mr.Header), + AddressID: fromAddr.ID, + } + + // Create an empty draft + plaintext, err := msg.Encrypt([]*openpgp.Entity{privateKey}, privateKey) + if err != nil { + return nil, err + } + if err := plaintext.Close(); err != nil { + return nil, err + } + + // TODO: parentID from In-Reply-To + msg, err = c.CreateDraftMessage(msg, "") + if err != nil { + return nil, fmt.Errorf("cannot create draft message: %v", err) + } + + var body *bytes.Buffer + var bodyType string + + for { + p, err := mr.NextPart() + if err == io.EOF { + break + } else if err != nil { + return nil, err + } + + switch h := p.Header.(type) { + case mail.TextHeader: + t, _, err := h.ContentType() + if err != nil { + break + } + + if body != nil && t != "text/html" { + break + } + + body = &bytes.Buffer{} + bodyType = t + if _, err := io.Copy(body, p.Body); err != nil { + return nil, err + } + case mail.AttachmentHeader: + t, _, err := h.ContentType() + if err != nil { + break + } + + filename, err := h.Filename() + if err != nil { + break + } + + att := &protonmail.Attachment{ + MessageID: msg.ID, + Name: filename, + MIMEType: t, + ContentID: h.Get("Content-Id"), + // TODO: Header + } + + _, err = att.GenerateKey([]*openpgp.Entity{privateKey}) + if err != nil { + return nil, fmt.Errorf("cannot generate attachment key: %v", err) + } + + pr, pw := io.Pipe() + + go func() { + cleartext, err := att.Encrypt(pw, privateKey) + if err != nil { + pw.CloseWithError(err) + return + } + if _, err := io.Copy(cleartext, p.Body); err != nil { + pw.CloseWithError(err) + return + } + pw.CloseWithError(cleartext.Close()) + }() + + att, err = c.CreateAttachment(att, pr) + if err != nil { + return nil, fmt.Errorf("cannot upload attachment: %v", err) + } + + msg.Attachments = append(msg.Attachments, att) + } + } + + if body == nil { + return nil, errors.New("message doesn't contain a body part") + } + + // Encrypt the body and update the draft + msg.MIMEType = bodyType + plaintext, err = msg.Encrypt([]*openpgp.Entity{privateKey}, privateKey) + if err != nil { + return nil, err + } + if _, err := io.Copy(plaintext, body); err != nil { + return nil, err + } + if err := plaintext.Close(); err != nil { + return nil, err + } + + if _, err := c.UpdateDraftMessage(msg); err != nil { + return nil, fmt.Errorf("cannot update draft message: %v", err) + } + + return msg, nil +} diff --git a/imap/user.go b/imap/user.go new file mode 100644 index 0000000..fcb7f4b --- /dev/null +++ b/imap/user.go @@ -0,0 +1,324 @@ +package imap + +import ( + "log" + "sync" + + "golang.org/x/crypto/openpgp" + "github.com/emersion/go-imap" + imapbackend "github.com/emersion/go-imap/backend" + "github.com/emersion/go-imap-specialuse" + + "github.com/emersion/hydroxide/events" + "github.com/emersion/hydroxide/imap/database" + "github.com/emersion/hydroxide/protonmail" +) + +var systemMailboxes = []struct{ + name string + label string + flags []string +}{ + {imap.InboxName, protonmail.LabelInbox, nil}, + {"All Mail", protonmail.LabelAllMail, []string{specialuse.All}}, + {"Archive", protonmail.LabelArchive, []string{specialuse.Archive}}, + {"Drafts", protonmail.LabelDraft, []string{specialuse.Drafts}}, + {"Starred", protonmail.LabelStarred, []string{specialuse.Flagged}}, + {"Spam", protonmail.LabelSpam, []string{specialuse.Junk}}, + {"Sent", protonmail.LabelSent, []string{specialuse.Sent}}, + {"Trash", protonmail.LabelTrash, []string{specialuse.Trash}}, +} + +type user struct { + c *protonmail.Client + u *protonmail.User + privateKeys openpgp.EntityList + + db *database.User + eventsReceiver *events.Receiver + + locker sync.Mutex + mailboxes map[string]*mailbox + + done chan<- struct{} + eventSent chan struct{} +} + +func newUser(be *backend, c *protonmail.Client, u *protonmail.User, privateKeys openpgp.EntityList) (*user, error) { + uu := &user{ + c: c, + u: u, + privateKeys: privateKeys, + eventSent: make(chan struct{}), + } + + db, err := database.Open(u.Name+".db") + if err != nil { + return nil, err + } + uu.db = db + + if err := uu.initMailboxes(); err != nil { + return nil, err + } + + done := make(chan struct{}) + uu.done = done + ch := make(chan *protonmail.Event) + go uu.receiveEvents(be.updates, ch) + uu.eventsReceiver = be.eventsManager.Register(c, u.Name, ch, done) + + return uu, nil +} + +func (u *user) initMailboxes() error { + u.locker.Lock() + defer u.locker.Unlock() + + u.mailboxes = make(map[string]*mailbox) + + for _, data := range systemMailboxes { + mboxDB, err := u.db.Mailbox(data.label) + if err != nil { + return err + } + + u.mailboxes[data.label] = &mailbox{ + name: data.name, + label: data.label, + flags: data.flags, + u: u, + db: mboxDB, + deleted: make(map[string]struct{}), + } + } + + counts, err := u.c.CountMessages("") + if err != nil { + return err + } + + for _, count := range counts { + if mbox, ok := u.mailboxes[count.LabelID]; ok { + mbox.total = count.Total + mbox.unread = count.Unread + } + } + + return nil +} + +func (u *user) Username() string { + return u.u.Name +} + +func (u *user) ListMailboxes(subscribed bool) ([]imapbackend.Mailbox, error) { + u.locker.Lock() + defer u.locker.Unlock() + + list := make([]imapbackend.Mailbox, 0, len(u.mailboxes)) + for _, mbox := range u.mailboxes { + list = append(list, mbox) + } + return list, nil +} + +func (u *user) getMailboxByLabel(labelID string) *mailbox { + u.locker.Lock() + defer u.locker.Unlock() + return u.mailboxes[labelID] +} + +func (u *user) getMailbox(name string) *mailbox { + u.locker.Lock() + defer u.locker.Unlock() + + for _, mbox := range u.mailboxes { + if mbox.name == name { + return mbox + } + } + return nil +} + +func (u *user) GetMailbox(name string) (imapbackend.Mailbox, error) { + mbox := u.getMailbox(name) + if mbox == nil { + return nil, imapbackend.ErrNoSuchMailbox + } + return mbox, nil +} + +func (u *user) CreateMailbox(name string) error { + return errNotYetImplemented // TODO +} + +func (u *user) DeleteMailbox(name string) error { + return errNotYetImplemented // TODO +} + +func (u *user) RenameMailbox(existingName, newName string) error { + return errNotYetImplemented // TODO +} + +func (u *user) Logout() error { + close(u.done) + + if err := u.db.Close(); err != nil { + return err + } + + u.c = nil + u.u = nil + u.privateKeys = nil + return nil +} + +func (u *user) poll() { + go u.eventsReceiver.Poll() + <-u.eventSent +} + +func (u *user) receiveEvents(updates chan<- interface{}, events <-chan *protonmail.Event) { + for event := range events { + var eventUpdates []interface{} + + if event.Refresh&protonmail.EventRefreshMail != 0 { + log.Println("Reinitializing the whole IMAP database") + + u.locker.Lock() + for _, mbox := range u.mailboxes { + if err := mbox.reset(); err != nil { + log.Printf("cannot reset mailbox %s: %v", mbox.name, err) + } + } + u.locker.Unlock() + + if err := u.db.ResetMessages(); err != nil { + log.Printf("cannot reset user: %v", err) + } + + if err := u.initMailboxes(); err != nil { + log.Printf("cannot reinitialize mailboxes: %v", err) + } + } else { + for _, eventMessage := range event.Messages { + switch eventMessage.Action { + case protonmail.EventCreate: + log.Println("Received create event for message", eventMessage.ID) + seqNums, err := u.db.CreateMessage(eventMessage.Created) + if err != nil { + log.Printf("cannot handle create event for message %s: cannot create message in local DB: %v", eventMessage.ID, err) + break + } + + // TODO: what if the message was already in the local DB? + for labelID, seqNum := range seqNums { + if mbox := u.getMailboxByLabel(labelID); mbox != nil { + update := new(imapbackend.MailboxUpdate) + update.Username = u.u.Name + update.Mailbox = mbox.name + update.MailboxStatus = imap.NewMailboxStatus(mbox.name, []imap.StatusItem{imap.StatusMessages}) + update.MailboxStatus.Messages = seqNum + eventUpdates = append(eventUpdates, update) + } + } + case protonmail.EventUpdate, protonmail.EventUpdateFlags: + log.Println("Received update event for message", eventMessage.ID) + createdSeqNums, deletedSeqNums, err := u.db.UpdateMessage(eventMessage.ID, eventMessage.Updated) + if err != nil { + log.Printf("cannot handle update event for message %s: cannot update message in local DB: %v", eventMessage.ID, err) + break + } + + for labelID, seqNum := range createdSeqNums { + if mbox := u.getMailboxByLabel(labelID); mbox != nil { + update := new(imapbackend.MailboxUpdate) + update.Username = u.u.Name + update.Mailbox = mbox.name + update.MailboxStatus = imap.NewMailboxStatus(mbox.name, []imap.StatusItem{imap.StatusMessages}) + update.MailboxStatus.Messages = seqNum + eventUpdates = append(eventUpdates, update) + } + } + for labelID, seqNum := range deletedSeqNums { + if mbox := u.getMailboxByLabel(labelID); mbox != nil { + update := new(imapbackend.ExpungeUpdate) + update.Username = u.u.Name + update.Mailbox = mbox.name + update.SeqNum = seqNum + eventUpdates = append(eventUpdates, update) + } + } + + // Send message updates + msg, err := u.db.Message(eventMessage.ID) + if err != nil { + log.Printf("cannot handle update event for message %s: cannot get updated message from local DB: %v", eventMessage.ID, err) + break + } + for _, labelID := range msg.LabelIDs { + if _, created := createdSeqNums[labelID]; created { + // This message has been added to the label's mailbox + // No need to send a message update + continue + } + + if mbox := u.getMailboxByLabel(labelID); mbox != nil { + seqNum, _, err := mbox.db.FromApiID(eventMessage.ID) + if err != nil { + log.Printf("cannot handle update event for message %s: cannot get message sequence number in %s: %v", eventMessage.ID, mbox.name, err) + continue + } + + update := new(imapbackend.MessageUpdate) + update.Username = u.u.Name + update.Mailbox = mbox.name + update.Message = imap.NewMessage(seqNum, []imap.FetchItem{imap.FetchFlags}) + update.Message.Flags = fetchFlags(msg) + eventUpdates = append(eventUpdates, update) + } + } + case protonmail.EventDelete: + log.Println("Received delete event for message", eventMessage.ID) + seqNums, err := u.db.DeleteMessage(eventMessage.ID) + if err != nil { + log.Printf("cannot handle delete event for message %s: cannot delete message from local DB: %v", eventMessage.ID, err) + break + } + + for labelID, seqNum := range seqNums { + if mbox := u.getMailboxByLabel(labelID); mbox != nil { + update := new(imapbackend.ExpungeUpdate) + update.Username = u.u.Name + update.Mailbox = mbox.name + update.SeqNum = seqNum + eventUpdates = append(eventUpdates, update) + } + } + } + } + + u.locker.Lock() + for _, count := range event.MessageCounts { + if mbox, ok := u.mailboxes[count.LabelID]; ok { + mbox.total = count.Total + mbox.unread = count.Unread + } + } + u.locker.Unlock() + } + + done := imapbackend.WaitUpdates(eventUpdates...) + for _, update := range eventUpdates { + updates <- update + } + go func() { + <-done + select { + case u.eventSent <- struct{}{}: + default: + } + }() + } +} diff --git a/protonmail/attachments.go b/protonmail/attachments.go index 6427c16..509b579 100644 --- a/protonmail/attachments.go +++ b/protonmail/attachments.go @@ -91,8 +91,22 @@ func (att *Attachment) Encrypt(ciphertext io.Writer, signed *openpgp.Entity) (cl return symetricallyEncrypt(ciphertext, att.unencryptedKey, nil, hints, config) } +func (att *Attachment) Read(ciphertext io.Reader, keyring openpgp.KeyRing, prompt openpgp.PromptFunction) (*openpgp.MessageDetails, error) { + if len(att.KeyPackets) == 0 { + return &openpgp.MessageDetails{ + IsEncrypted: false, + IsSigned: false, + UnverifiedBody: ciphertext, + }, nil + } else { + kpr := base64.NewDecoder(base64.StdEncoding, strings.NewReader(att.KeyPackets)) + r := io.MultiReader(kpr, ciphertext) + return openpgp.ReadMessage(r, keyring, prompt, nil) + } +} + // GetAttachment downloads an attachment's payload. The returned io.ReadCloser -// may be encrypted. +// may be encrypted, use Attachment.Read to decrypt it. func (c *Client) GetAttachment(id string) (io.ReadCloser, error) { req, err := c.newRequest(http.MethodGet, "/attachments/"+id, nil) if err != nil { diff --git a/protonmail/events.go b/protonmail/events.go index 7585a8e..ba41e22 100644 --- a/protonmail/events.go +++ b/protonmail/events.go @@ -1,13 +1,21 @@ package protonmail import ( + "encoding/json" "net/http" ) +type EventRefresh int + +const ( + EventRefreshMail EventRefresh = 1 << iota + EventRefreshContacts +) + type Event struct { ID string `json:"EventID"` - Refresh int - //Messages + Refresh EventRefresh + Messages []*EventMessage Contacts []*EventContact //ContactEmails //Labels @@ -15,7 +23,7 @@ type Event struct { //Members //Domains //Organization - //MessageCounts + MessageCounts []*MessageCount //ConversationCounts //UsedSpace Notices []string @@ -27,8 +35,125 @@ const ( EventDelete EventAction = iota EventCreate EventUpdate + + // For messages + EventUpdateFlags ) +type EventMessage struct { + ID string + Action EventAction + + // Only populated for EventCreate + Created *Message + // Only populated for EventUpdate or EventUpdateFlags + Updated *EventMessageUpdate +} + +type EventMessageUpdate struct { + IsRead *int + Type *MessageType + Time int64 + IsReplied *int + IsRepliedAll *int + IsForwarded *int + + // Only populated for EventUpdateFlags + LabelIDs []string + LabelIDsAdded []string + LabelIDsRemoved []string +} + +func buildLabelsSet(labelIDs []string) map[string]struct{} { + set := make(map[string]struct{}, len(labelIDs)) + for _, labelID := range labelIDs { + set[labelID] = struct{}{} + } + return set +} + +func (update *EventMessageUpdate) DiffLabelIDs(current []string) (added, removed []string) { + if update.LabelIDsAdded != nil && update.LabelIDsRemoved != nil { + return update.LabelIDsAdded, update.LabelIDsRemoved + } + if update.LabelIDs == nil { + return + } + + currentSet := buildLabelsSet(current) + updateSet := buildLabelsSet(update.LabelIDs) + for labelID := range currentSet { + if _, ok := updateSet[labelID]; !ok { + removed = append(removed, labelID) + } + } + for labelID := range updateSet { + if _, ok := currentSet[labelID]; !ok { + added = append(added, labelID) + } + } + return +} + +func (update *EventMessageUpdate) Patch(msg *Message) { + msg.Time = update.Time + if update.IsRead != nil { + msg.IsRead = *update.IsRead + } + if update.Type != nil { + msg.Type = *update.Type + } + if update.IsReplied != nil { + msg.IsReplied = *update.IsReplied + } + if update.IsRepliedAll != nil { + msg.IsRepliedAll = *update.IsRepliedAll + } + if update.IsForwarded != nil { + msg.IsForwarded = *update.IsForwarded + } + + if update.LabelIDs != nil { + msg.LabelIDs = update.LabelIDs + } else if update.LabelIDsAdded != nil && update.LabelIDsRemoved != nil { + set := buildLabelsSet(msg.LabelIDs) + for _, labelID := range update.LabelIDsAdded { + set[labelID] = struct{}{} + } + for _, labelID := range update.LabelIDsRemoved { + delete(set, labelID) + } + msg.LabelIDs = make([]string, 0, len(set)) + for labelID := range set { + msg.LabelIDs = append(msg.LabelIDs, labelID) + } + } +} + +type rawEventMessage struct { + ID string + Action EventAction + Message json.RawMessage `json:",omitempty"` +} + +func (em *EventMessage) UnmarshalJSON(b []byte) error { + var raw rawEventMessage + if err := json.Unmarshal(b, &raw); err != nil { + return err + } + em.ID = raw.ID + em.Action = raw.Action + switch raw.Action { + case EventCreate: + em.Created = new(Message) + return json.Unmarshal(raw.Message, em.Created) + case EventUpdate, EventUpdateFlags: + em.Updated = new(EventMessageUpdate) + return json.Unmarshal(raw.Message, em.Updated) + } + return nil +} + type EventContact struct { ID string Action EventAction diff --git a/protonmail/labels.go b/protonmail/labels.go new file mode 100644 index 0000000..b019981 --- /dev/null +++ b/protonmail/labels.go @@ -0,0 +1,14 @@ +package protonmail + +const ( + LabelInbox = "0" + LabelAllDraft = "1" + LabelAllSent = "2" + LabelTrash = "3" + LabelSpam = "4" + LabelAllMail = "5" + LabelArchive = "6" + LabelSent = "7" + LabelDraft = "8" + LabelStarred = "10" +) diff --git a/protonmail/messages.go b/protonmail/messages.go index fc9766b..8b27642 100644 --- a/protonmail/messages.go +++ b/protonmail/messages.go @@ -6,6 +6,7 @@ import ( "errors" "io" "net/http" + "net/url" "strconv" "strings" @@ -55,6 +56,7 @@ type Message struct { ToList []*MessageAddress Time int64 Size int64 + NumAttachments int IsEncrypted MessageEncryption ExpirationTime int64 IsReplied int @@ -132,6 +134,103 @@ func (msg *Message) Encrypt(to []*openpgp.Entity, signed *openpgp.Entity) (plain }, nil } +type MessageFilter struct { + Page int + PageSize int + Limit int + + Label string + Sort string + Asc bool + Begin int64 + End int64 + Keyword string + To string + From string + Subject string + Attachments *bool + Starred *bool + Unread *bool + Conversation string + Address string + ID []string + ExternalID string +} + +func (c *Client) ListMessages(filter *MessageFilter) (total int, messages []*Message, err error) { + v := url.Values{} + if filter.Page != 0 { + v.Set("Page", strconv.Itoa(filter.Page)) + } + if filter.PageSize != 0 { + v.Set("PageSize", strconv.Itoa(filter.PageSize)) + } + if filter.Limit != 0 { + v.Set("Limit", strconv.Itoa(filter.Limit)) + } + if filter.Label != "" { + v.Set("Label", filter.Label) + } + if filter.Sort != "" { + v.Set("Sort", filter.Sort) + } + if filter.Asc { + v.Set("Desc", "0") + } + if filter.Conversation != "" { + v.Set("Conversation", filter.Conversation) + } + if filter.Address != "" { + v.Set("Address", filter.Address) + } + if filter.ExternalID != "" { + v.Set("ExternalID", filter.ExternalID) + } + + req, err := c.newRequest(http.MethodGet, "/messages?"+v.Encode(), nil) + if err != nil { + return 0, nil, err + } + + var respData struct { + resp + Total int + Messages []*Message + } + if err := c.doJSON(req, &respData); err != nil { + return 0, nil, err + } + + return respData.Total, respData.Messages, nil +} + +type MessageCount struct { + LabelID string + Total int + Unread int +} + +func (c *Client) CountMessages(address string) ([]*MessageCount, error) { + v := url.Values{} + if address != "" { + v.Set("Address", address) + } + req, err := c.newRequest(http.MethodGet, "/messages/count?"+v.Encode(), nil) + if err != nil { + return nil, err + } + + var respData struct { + resp + Counts []*MessageCount + } + if err := c.doJSON(req, &respData); err != nil { + return nil, err + } + + return respData.Counts, nil +} + func (c *Client) GetMessage(id string) (*Message, error) { req, err := c.newRequest(http.MethodGet, "/messages/"+id, nil) if err != nil { @@ -193,6 +292,63 @@ func (c *Client) UpdateDraftMessage(msg *Message) (*Message, error) { return respData.Message, nil } +func (c *Client) doMessages(action string, ids []string) error { + reqData := struct { + IDs []string + }{ids} + req, err := c.newJSONRequest(http.MethodPut, "/messages/"+action, &reqData) + if err != nil { + return err + } + + // TODO: the response contains one response per message + return c.doJSON(req, nil) +} + +func (c *Client) MarkMessagesRead(ids []string) error { + return c.doMessages("read", ids) +} + +func (c *Client) MarkMessagesUnread(ids []string) error { + return c.doMessages("unread", ids) +} + +func (c *Client) DeleteMessages(ids []string) error { + return c.doMessages("delete", ids) +} + +func (c *Client) UndeleteMessages(ids []string) error { + return c.doMessages("undelete", ids) +} + +func (c *Client) LabelMessages(labelID string, ids []string) error { + reqData := struct { + LabelID string + IDs []string + }{labelID, ids} + req, err := c.newJSONRequest(http.MethodPut, "/messages/label", &reqData) + if err != nil { + return err + } + + // TODO: the response contains one response per message + return c.doJSON(req, nil) +} + +func (c *Client) UnlabelMessages(labelID string, ids []string) error { + reqData := struct { + LabelID string + IDs []string + }{labelID, ids} + req, err := c.newJSONRequest(http.MethodPut, "/messages/unlabel", &reqData) + if err != nil { + return err + } + + // TODO: the response contains one response per message + return c.doJSON(req, nil) +} + type MessageKeyPacket struct { ID string KeyPackets string diff --git a/smtp/smtp.go b/smtp/smtp.go index 4bef59a..d794903 100644 --- a/smtp/smtp.go +++ b/smtp/smtp.go @@ -311,6 +311,7 @@ func (u *user) Send(from string, to []string, r io.Reader) error { func (u *user) Logout() error { u.c = nil + u.u = nil u.privateKeys = nil return nil }