diff --git a/imap/mailbox.go b/imap/mailbox.go index df9cc03..a319ded 100644 --- a/imap/mailbox.go +++ b/imap/mailbox.go @@ -24,9 +24,9 @@ type mailbox struct { u *user db *database.Mailbox - initialized bool - initializedLock sync.Mutex + sync.Mutex // protects everything below + initialized bool total, unread int deleted map[string]struct{} } @@ -60,20 +60,23 @@ func (mbox *mailbox) Info() (*imap.MailboxInfo, error) { } func (mbox *mailbox) Status(items []imap.StatusItem) (*imap.MailboxStatus, error) { - mbox.u.locker.Lock() + mbox.u.Lock() flags := []string{imap.SeenFlag, imap.DeletedFlag} permFlags := []string{imap.SeenFlag} for _, flag := range mbox.u.flags { flags = append(flags, flag) permFlags = append(permFlags, flag) } - mbox.u.locker.Unlock() + mbox.u.Unlock() status := imap.NewMailboxStatus(mbox.name, items) status.Flags = flags status.PermanentFlags = permFlags status.UnseenSeqNum = 0 // TODO + mbox.Lock() + defer mbox.Unlock() + for _, name := range items { switch name { case imap.StatusMessages: @@ -141,13 +144,12 @@ func (mbox *mailbox) sync() error { } log.Printf("Synchronizing mailbox %v: done.", mbox.name) - return nil } func (mbox *mailbox) init() error { - mbox.initializedLock.Lock() - defer mbox.initializedLock.Unlock() + mbox.Lock() + defer mbox.Unlock() if mbox.initialized { return nil @@ -163,11 +165,10 @@ func (mbox *mailbox) init() error { } func (mbox *mailbox) reset() error { - mbox.initializedLock.Lock() - defer mbox.initializedLock.Unlock() + mbox.Lock() + defer mbox.Unlock() mbox.initialized = false - return mbox.db.Reset() } @@ -179,14 +180,21 @@ func (mbox *mailbox) fetchFlags(msg *protonmail.Message) []string { if msg.IsReplied != 0 || msg.IsRepliedAll != 0 { flags = append(flags, imap.AnsweredFlag) } + + mbox.Lock() if _, ok := mbox.deleted[msg.ID]; ok { flags = append(flags, imap.DeletedFlag) } + mbox.Unlock() + + mbox.u.Lock() for _, label := range msg.LabelIDs { if flag, ok := mbox.u.flags[label]; ok { flags = append(flags, flag) } } + mbox.u.Unlock() + return flags } @@ -301,7 +309,7 @@ func (mbox *mailbox) SearchMessages(isUID bool, c *imap.SearchCriteria) ([]uint3 // TODO: c.Not, c.Or if c.Not != nil || c.Or != nil { - return nil, errors.New("search queries with NOT or OR clauses or not yet implemented") + return nil, errors.New("search queries with NOT or OR clauses are not yet implemented") } var results []uint32 @@ -457,6 +465,7 @@ func (mbox *mailbox) UpdateMessagesFlags(uid bool, seqSet *imap.SeqSet, op imap. } case imap.DeletedFlag: // TODO: send updates + mbox.Lock() switch op { case imap.SetFlags, imap.AddFlags: for _, apiID := range apiIDs { @@ -467,6 +476,7 @@ func (mbox *mailbox) UpdateMessagesFlags(uid bool, seqSet *imap.SeqSet, op imap. delete(mbox.deleted, apiID) } } + mbox.Unlock() case imap.DraftFlag: // No-op default: @@ -541,9 +551,11 @@ func (mbox *mailbox) Expunge() error { } apiIDs := make([]string, 0, len(mbox.deleted)) + mbox.Lock() for apiID := range mbox.deleted { apiIDs = append(apiIDs, apiID) } + mbox.Unlock() if err := mbox.u.c.DeleteMessages(apiIDs); err != nil { return err diff --git a/imap/user.go b/imap/user.go index 10eb597..1cfed55 100644 --- a/imap/user.go +++ b/imap/user.go @@ -44,22 +44,25 @@ type user struct { u *protonmail.User privateKeys openpgp.EntityList addrs []*protonmail.Address - numClients int db *database.User eventsReceiver *events.Receiver - locker sync.Mutex - mailboxes map[string]*mailbox // indexed by label ID - flags map[string]string // indexed by label ID - done chan<- struct{} eventSent chan struct{} + + sync.Mutex // protects everything below + + numClients int + mailboxes map[string]*mailbox // indexed by label ID + flags map[string]string // indexed by label ID } func getUser(be *backend, username string, c *protonmail.Client, privateKeys openpgp.EntityList) (*user, error) { if u, ok := be.users[username]; ok { + u.Lock() u.numClients++ + u.Unlock() return u, nil } else { pu, err := c.GetCurrentUser() @@ -141,8 +144,8 @@ func labelNameToFlag(s string) string { } func (u *user) initMailboxes() error { - u.locker.Lock() - defer u.locker.Unlock() + u.Lock() + defer u.Unlock() u.mailboxes = make(map[string]*mailbox) for _, data := range systemMailboxes { @@ -202,8 +205,8 @@ func (u *user) Username() string { } func (u *user) ListMailboxes(subscribed bool) ([]imapbackend.Mailbox, error) { - u.locker.Lock() - defer u.locker.Unlock() + u.Lock() + defer u.Unlock() list := make([]imapbackend.Mailbox, 0, len(u.mailboxes)) for _, mbox := range u.mailboxes { @@ -213,14 +216,14 @@ func (u *user) ListMailboxes(subscribed bool) ([]imapbackend.Mailbox, error) { } func (u *user) getMailboxByLabel(labelID string) *mailbox { - u.locker.Lock() - defer u.locker.Unlock() + u.Lock() + defer u.Unlock() return u.mailboxes[labelID] } func (u *user) getMailbox(name string) *mailbox { - u.locker.Lock() - defer u.locker.Unlock() + u.Lock() + defer u.Unlock() for _, mbox := range u.mailboxes { if mbox.name == name { @@ -239,8 +242,8 @@ func (u *user) GetMailbox(name string) (imapbackend.Mailbox, error) { } func (u *user) getFlag(name string) string { - u.locker.Lock() - defer u.locker.Unlock() + u.Lock() + defer u.Unlock() for label, flag := range u.flags { if flag == name { @@ -263,6 +266,9 @@ func (u *user) RenameMailbox(existingName, newName string) error { } func (u *user) Logout() error { + u.Lock() + defer u.Unlock() + if u.numClients <= 0 { panic("unreachable") } @@ -298,13 +304,13 @@ func (u *user) receiveEvents(updates chan<- imapbackend.Update, events <-chan *p if event.Refresh&protonmail.EventRefreshMail != 0 { log.Println("Reinitializing the whole IMAP database") - u.locker.Lock() + u.Lock() for _, mbox := range u.mailboxes { if err := mbox.reset(); err != nil { log.Printf("cannot reset mailbox %s: %v", mbox.name, err) } } - u.locker.Unlock() + u.Unlock() if err := u.db.ResetMessages(); err != nil { log.Printf("cannot reset user: %v", err) @@ -406,14 +412,14 @@ func (u *user) receiveEvents(updates chan<- imapbackend.Update, events <-chan *p } } - u.locker.Lock() + u.Lock() for _, count := range event.MessageCounts { if mbox, ok := u.mailboxes[count.LabelID]; ok { mbox.total = count.Total mbox.unread = count.Unread } } - u.locker.Unlock() + u.Unlock() } for _, update := range eventUpdates {