From f968ad6b9deed64411db08c1cefeb7252043da22 Mon Sep 17 00:00:00 2001 From: emersion Date: Thu, 11 Jan 2018 14:40:05 +0100 Subject: [PATCH] imap: update local DB from events --- imap/database/mailbox.go | 57 ++++++------ imap/database/user.go | 186 ++++++++++++++++++++++++++++++++------- imap/user.go | 42 ++++++--- protonmail/events.go | 115 +++++++++++++++++++++++- 4 files changed, 326 insertions(+), 74 deletions(-) diff --git a/imap/database/mailbox.go b/imap/database/mailbox.go index 5040af9..bd1bba1 100644 --- a/imap/database/mailbox.go +++ b/imap/database/mailbox.go @@ -20,8 +20,33 @@ func unserializeUID(b []byte) uint32 { return binary.BigEndian.Uint32(b) } +func mailboxCreateMessage(b *bolt.Bucket, apiID string) error { + want := []byte(apiID) + c := b.Cursor() + for k, v := c.First(); k != nil; k, v = c.Next() { + if bytes.Equal(v, want) { + return nil + } + } + + id, _ := b.NextSequence() + uid := uint32(id) + return b.Put(serializeUID(uid), want) +} + +func mailboxDeleteMessage(b *bolt.Bucket, apiID string) error { + want := []byte(apiID) + c := b.Cursor() + for k, v := c.First(); k != nil; k, v = c.Next() { + if bytes.Equal(v, want) { + return b.Delete(k) + } + } + return nil +} + type Mailbox struct { - name string + labelID string u *User } @@ -30,7 +55,7 @@ func (mbox *Mailbox) bucket(tx *bolt.Tx) (*bolt.Bucket, error) { if b == nil { return nil, errors.New("cannot find mailboxes bucket") } - b = b.Bucket([]byte(mbox.name)) + b = b.Bucket([]byte(mbox.labelID)) if b == nil { return nil, errors.New("cannot find mailbox bucket") } @@ -38,40 +63,20 @@ func (mbox *Mailbox) bucket(tx *bolt.Tx) (*bolt.Bucket, error) { } func (mbox *Mailbox) Sync(messages []*protonmail.Message) error { - err := mbox.u.db.Update(func(tx *bolt.Tx) error { + return mbox.u.db.Update(func(tx *bolt.Tx) error { b, err := mbox.bucket(tx) if err != nil { return err } for _, msg := range messages { - want := []byte(msg.ID) - c := b.Cursor() - found := false - for k, v := c.First(); k != nil; k, v = c.Next() { - if bytes.Equal(v, want) { - found = true - break - } - } - if found { - continue - } - - id, _ := b.NextSequence() - uid := uint32(id) - if err := b.Put(serializeUID(uid), want); err != nil { + if err := mailboxCreateMessage(b, msg.ID); err != nil { return err } } - return nil + return userSync(tx, messages) }) - if err != nil { - return err - } - - return mbox.u.sync(messages) } func (mbox *Mailbox) UidNext() (uint32, error) { @@ -181,7 +186,7 @@ func (mbox *Mailbox) Reset() error { if b == nil { return errors.New("cannot find mailboxes bucket") } - k := []byte(mbox.name) + k := []byte(mbox.labelID) if err := b.DeleteBucket(k); err != nil { return err } diff --git a/imap/database/user.go b/imap/database/user.go index 43b419d..66fd508 100644 --- a/imap/database/user.go +++ b/imap/database/user.go @@ -16,46 +16,60 @@ var ( messagesBucket = []byte("messages") ) +func userMessage(b *bolt.Bucket, apiID string) (*protonmail.Message, error) { + k := []byte(apiID) + v := b.Get(k) + if v == nil { + return nil, ErrNotFound + } + + msg := &protonmail.Message{} + err := json.Unmarshal(v, msg) + return msg, err +} + +func userCreateMessage(b *bolt.Bucket, msg *protonmail.Message) error { + k := []byte(msg.ID) + v, err := json.Marshal(msg) + if err != nil { + return err + } + return b.Put(k, v) +} + +func userSync(tx *bolt.Tx, messages []*protonmail.Message) error { + b, err := tx.CreateBucketIfNotExists(messagesBucket) + if err != nil { + return err + } + + for _, msg := range messages { + if err := userCreateMessage(b, msg); err != nil { + return err + } + } + + return nil +} + type User struct { db *bolt.DB } -func (u *User) Mailbox(name string) (*Mailbox, error) { +func (u *User) Mailbox(labelID string) (*Mailbox, error) { err := u.db.Update(func(tx *bolt.Tx) error { b, err := tx.CreateBucketIfNotExists(mailboxesBucket) if err != nil { return err } - _, err = b.CreateBucketIfNotExists([]byte(name)) + _, err = b.CreateBucketIfNotExists([]byte(labelID)) return err }) if err != nil { return nil, err } - return &Mailbox{name, u}, nil -} - -func (u *User) sync(messages []*protonmail.Message) error { - return u.db.Update(func(tx *bolt.Tx) error { - b, err := tx.CreateBucketIfNotExists(messagesBucket) - if err != nil { - return err - } - - for _, msg := range messages { - k := []byte(msg.ID) - v, err := json.Marshal(msg) - if err != nil { - return err - } - if err := b.Put(k, v); err != nil { - return err - } - } - - return nil - }) + return &Mailbox{labelID, u}, nil } func (u *User) Message(apiID string) (*protonmail.Message, error) { @@ -66,14 +80,9 @@ func (u *User) Message(apiID string) (*protonmail.Message, error) { return ErrNotFound } - k := []byte(apiID) - v := b.Get(k) - if v == nil { - return ErrNotFound - } - - msg = &protonmail.Message{} - return json.Unmarshal(v, msg) + var err error + msg, err = userMessage(b, apiID) + return err }) return msg, err } @@ -84,6 +93,117 @@ func (u *User) ResetMessages() error { }) } +func (u *User) CreateMessage(msg *protonmail.Message) error { + return u.db.Update(func(tx *bolt.Tx) error { + messages, err := tx.CreateBucketIfNotExists(messagesBucket) + if err != nil { + return err + } + + if err := userCreateMessage(messages, msg); err != nil { + return err + } + + mailboxes, err := tx.CreateBucketIfNotExists(mailboxesBucket) + if err != nil { + return err + } + for _, labelID := range msg.LabelIDs { + mbox, err := mailboxes.CreateBucketIfNotExists([]byte(labelID)) + if err != nil { + return err + } + + if err := mailboxCreateMessage(mbox, msg.ID); err != nil { + return err + } + } + + return nil + }) +} + +func (u *User) UpdateMessage(update *protonmail.EventMessageUpdate) error { + return u.db.Update(func(tx *bolt.Tx) error { + messages := tx.Bucket(messagesBucket) + if messages == nil { + return errors.New("cannot update message in local DB: messages bucket doesn't exist") + } + + msg, err := userMessage(messages, update.ID) + if err != nil { + return err + } + + addedLabels, removedLabels := update.DiffLabelIDs(msg.LabelIDs) + + mailboxes, err := tx.CreateBucketIfNotExists(mailboxesBucket) + if err != nil { + return err + } + for _, labelID := range addedLabels { + mbox, err := mailboxes.CreateBucketIfNotExists([]byte(labelID)) + if err != nil { + return err + } + + if err := mailboxCreateMessage(mbox, update.ID); err != nil { + return err + } + } + for _, labelID := range removedLabels { + mbox := mailboxes.Bucket([]byte(labelID)) + if mbox == nil { + continue + } + + if err := mailboxDeleteMessage(mbox, update.ID); err != nil { + return err + } + } + + update.Patch(msg) + return userCreateMessage(messages, msg) + }) +} + +func (u *User) DeleteMessage(apiID string) error { + return u.db.Update(func(tx *bolt.Tx) error { + messages:= tx.Bucket(messagesBucket) + if messages == nil { + return nil + } + + msg, err := userMessage(messages, apiID) + if err == ErrNotFound { + return nil + } else if err != nil { + return err + } + + if err := messages.Delete([]byte(apiID)); err != nil { + return err + } + + mailboxes := tx.Bucket(mailboxesBucket) + if mailboxes == nil { + return nil + } + for _, labelID := range msg.LabelIDs { + mbox := mailboxes.Bucket([]byte(labelID)) + if mbox == nil { + continue + } + + if err := mailboxDeleteMessage(mbox, msg.ID); err != nil { + return err + } + } + + return nil + }) +} + func (u *User) Close() error { return u.db.Close() } diff --git a/imap/user.go b/imap/user.go index 458f925..7ee11ae 100644 --- a/imap/user.go +++ b/imap/user.go @@ -168,28 +168,42 @@ func (u *user) receiveEvents(events <-chan *protonmail.Event) { log.Printf("cannot reinitialize mailboxes: %v", err) } } else { + for _, eventMessage := range event.Messages { + switch eventMessage.Action { + case protonmail.EventCreate: + if err := u.db.CreateMessage(eventMessage.Created); err != nil { + log.Printf("cannot handle create event for message %s: cannot create message in local DB: %v", eventMessage.ID, err) + break + } + + // TODO: send updates + case protonmail.EventUpdate: + // No-op + case protonmail.EventUpdateFlags: + if err := u.db.UpdateMessage(eventMessage.Updated); err != nil { + log.Printf("cannot handle update event for message %s: cannot update message in local DB: %v", eventMessage.ID, err) + break + } + + // TODO: send updates + case protonmail.EventDelete: + if err := u.db.DeleteMessage(eventMessage.ID); err != nil { + log.Printf("cannot handle delete event for message %s: cannot delete message from local DB: %v", eventMessage.ID, err) + break + } + + // TODO: send updates + } + } + u.locker.Lock() for _, count := range event.MessageCounts { if mbox, ok := u.mailboxes[count.LabelID]; ok { mbox.total = count.Total mbox.unread = count.Unread - // TODO: send update } } u.locker.Unlock() - - for _, eventMessage := range event.Messages { - switch eventMessage.Action { - case protonmail.EventCreate: - // TODO - case protonmail.EventUpdate: - // TODO - case protonmail.EventUpdateFlags: - // TODO - case protonmail.EventDelete: - // TODO - } - } } } } diff --git a/protonmail/events.go b/protonmail/events.go index d7576ed..7696bb9 100644 --- a/protonmail/events.go +++ b/protonmail/events.go @@ -1,6 +1,7 @@ package protonmail import ( + "encoding/json" "net/http" ) @@ -42,7 +43,119 @@ const ( type EventMessage struct { ID string Action EventAction - Message *Message + + // Only populated for EventCreate + Created *Message + // Only populated for EventUpdate or EventUpdateFlags + Updated *EventMessageUpdate +} + +type EventMessageUpdate struct { + ID string + + IsRead *int + Type *MessageType + Time int64 + IsReplied *int + IsRepliedAll *int + IsForwarded *int + + // Only populated for EventUpdateFlags + LabelIDs []string + LabelIDsAdded []string + LabelIDsRemoved []string +} + +func buildLabelsSet(labelIDs []string) map[string]struct{} { + set := make(map[string]struct{}, len(labelIDs)) + for _, labelID := range labelIDs { + set[labelID] = struct{}{} + } + return set +} + +func (update *EventMessageUpdate) DiffLabelIDs(current []string) (added, removed []string) { + if update.LabelIDsAdded != nil && update.LabelIDsRemoved != nil { + return update.LabelIDsAdded, update.LabelIDsRemoved + } + if update.LabelIDs == nil { + return + } + + currentSet := buildLabelsSet(current) + updateSet := buildLabelsSet(update.LabelIDs) + for labelID := range currentSet { + if _, ok := updateSet[labelID]; !ok { + removed = append(removed, labelID) + } + } + for labelID := range updateSet { + if _, ok := currentSet[labelID]; !ok { + added = append(added, labelID) + } + } + return +} + +func (update *EventMessageUpdate) Patch(msg *Message) { + if update.ID != msg.ID { + return + } + + msg.Time = update.Time + if update.IsRead != nil { + msg.IsRead = *update.IsRead + } + if update.Type != nil { + msg.Type = *update.Type + } + if update.IsReplied != nil { + msg.IsReplied = *update.IsReplied + } + if update.IsRepliedAll != nil { + msg.IsRepliedAll = *update.IsRepliedAll + } + if update.IsForwarded != nil { + msg.IsForwarded = *update.IsForwarded + } + + if update.LabelIDs != nil { + msg.LabelIDs = update.LabelIDs + } else if update.LabelIDsAdded != nil && update.LabelIDsRemoved != nil { + set := buildLabelsSet(msg.LabelIDs) + for _, labelID := range update.LabelIDsAdded { + set[labelID] = struct{}{} + } + for _, labelID := range update.LabelIDsRemoved { + delete(set, labelID) + } + msg.LabelIDs = make([]string, 0, len(set)) + for labelID := range set { + msg.LabelIDs = append(msg.LabelIDs, labelID) + } + } +} + +type rawEventMessage struct { + ID string + Action EventAction + Message json.RawMessage `json:",omitempty"` +} + +func (em *EventMessage) UnmarshalJSON(b []byte) error { + var raw rawEventMessage + if err := json.Unmarshal(b, &raw); err != nil { + return err + } + em.ID = raw.ID + em.Action = raw.Action + switch raw.Action { + case EventCreate: + return json.Unmarshal(raw.Message, em.Created) + case EventUpdate, EventUpdateFlags: + return json.Unmarshal(raw.Message, em.Updated) + } + return nil } type EventContact struct {