imap: fix some race conditions

This commit is contained in:
Simon Ser 2020-02-29 11:59:15 +01:00
parent 8fef87f17f
commit cbcde22b5b
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
2 changed files with 48 additions and 30 deletions

View File

@ -24,9 +24,9 @@ type mailbox struct {
u *user u *user
db *database.Mailbox db *database.Mailbox
initialized bool sync.Mutex // protects everything below
initializedLock sync.Mutex
initialized bool
total, unread int total, unread int
deleted map[string]struct{} deleted map[string]struct{}
} }
@ -60,20 +60,23 @@ func (mbox *mailbox) Info() (*imap.MailboxInfo, error) {
} }
func (mbox *mailbox) Status(items []imap.StatusItem) (*imap.MailboxStatus, error) { func (mbox *mailbox) Status(items []imap.StatusItem) (*imap.MailboxStatus, error) {
mbox.u.locker.Lock() mbox.u.Lock()
flags := []string{imap.SeenFlag, imap.DeletedFlag} flags := []string{imap.SeenFlag, imap.DeletedFlag}
permFlags := []string{imap.SeenFlag} permFlags := []string{imap.SeenFlag}
for _, flag := range mbox.u.flags { for _, flag := range mbox.u.flags {
flags = append(flags, flag) flags = append(flags, flag)
permFlags = append(permFlags, flag) permFlags = append(permFlags, flag)
} }
mbox.u.locker.Unlock() mbox.u.Unlock()
status := imap.NewMailboxStatus(mbox.name, items) status := imap.NewMailboxStatus(mbox.name, items)
status.Flags = flags status.Flags = flags
status.PermanentFlags = permFlags status.PermanentFlags = permFlags
status.UnseenSeqNum = 0 // TODO status.UnseenSeqNum = 0 // TODO
mbox.Lock()
defer mbox.Unlock()
for _, name := range items { for _, name := range items {
switch name { switch name {
case imap.StatusMessages: case imap.StatusMessages:
@ -141,13 +144,12 @@ func (mbox *mailbox) sync() error {
} }
log.Printf("Synchronizing mailbox %v: done.", mbox.name) log.Printf("Synchronizing mailbox %v: done.", mbox.name)
return nil return nil
} }
func (mbox *mailbox) init() error { func (mbox *mailbox) init() error {
mbox.initializedLock.Lock() mbox.Lock()
defer mbox.initializedLock.Unlock() defer mbox.Unlock()
if mbox.initialized { if mbox.initialized {
return nil return nil
@ -163,11 +165,10 @@ func (mbox *mailbox) init() error {
} }
func (mbox *mailbox) reset() error { func (mbox *mailbox) reset() error {
mbox.initializedLock.Lock() mbox.Lock()
defer mbox.initializedLock.Unlock() defer mbox.Unlock()
mbox.initialized = false mbox.initialized = false
return mbox.db.Reset() return mbox.db.Reset()
} }
@ -179,14 +180,21 @@ func (mbox *mailbox) fetchFlags(msg *protonmail.Message) []string {
if msg.IsReplied != 0 || msg.IsRepliedAll != 0 { if msg.IsReplied != 0 || msg.IsRepliedAll != 0 {
flags = append(flags, imap.AnsweredFlag) flags = append(flags, imap.AnsweredFlag)
} }
mbox.Lock()
if _, ok := mbox.deleted[msg.ID]; ok { if _, ok := mbox.deleted[msg.ID]; ok {
flags = append(flags, imap.DeletedFlag) flags = append(flags, imap.DeletedFlag)
} }
mbox.Unlock()
mbox.u.Lock()
for _, label := range msg.LabelIDs { for _, label := range msg.LabelIDs {
if flag, ok := mbox.u.flags[label]; ok { if flag, ok := mbox.u.flags[label]; ok {
flags = append(flags, flag) flags = append(flags, flag)
} }
} }
mbox.u.Unlock()
return flags return flags
} }
@ -301,7 +309,7 @@ func (mbox *mailbox) SearchMessages(isUID bool, c *imap.SearchCriteria) ([]uint3
// TODO: c.Not, c.Or // TODO: c.Not, c.Or
if c.Not != nil || c.Or != nil { if c.Not != nil || c.Or != nil {
return nil, errors.New("search queries with NOT or OR clauses or not yet implemented") return nil, errors.New("search queries with NOT or OR clauses are not yet implemented")
} }
var results []uint32 var results []uint32
@ -457,6 +465,7 @@ func (mbox *mailbox) UpdateMessagesFlags(uid bool, seqSet *imap.SeqSet, op imap.
} }
case imap.DeletedFlag: case imap.DeletedFlag:
// TODO: send updates // TODO: send updates
mbox.Lock()
switch op { switch op {
case imap.SetFlags, imap.AddFlags: case imap.SetFlags, imap.AddFlags:
for _, apiID := range apiIDs { for _, apiID := range apiIDs {
@ -467,6 +476,7 @@ func (mbox *mailbox) UpdateMessagesFlags(uid bool, seqSet *imap.SeqSet, op imap.
delete(mbox.deleted, apiID) delete(mbox.deleted, apiID)
} }
} }
mbox.Unlock()
case imap.DraftFlag: case imap.DraftFlag:
// No-op // No-op
default: default:
@ -541,9 +551,11 @@ func (mbox *mailbox) Expunge() error {
} }
apiIDs := make([]string, 0, len(mbox.deleted)) apiIDs := make([]string, 0, len(mbox.deleted))
mbox.Lock()
for apiID := range mbox.deleted { for apiID := range mbox.deleted {
apiIDs = append(apiIDs, apiID) apiIDs = append(apiIDs, apiID)
} }
mbox.Unlock()
if err := mbox.u.c.DeleteMessages(apiIDs); err != nil { if err := mbox.u.c.DeleteMessages(apiIDs); err != nil {
return err return err

View File

@ -44,22 +44,25 @@ type user struct {
u *protonmail.User u *protonmail.User
privateKeys openpgp.EntityList privateKeys openpgp.EntityList
addrs []*protonmail.Address addrs []*protonmail.Address
numClients int
db *database.User db *database.User
eventsReceiver *events.Receiver eventsReceiver *events.Receiver
locker sync.Mutex
mailboxes map[string]*mailbox // indexed by label ID
flags map[string]string // indexed by label ID
done chan<- struct{} done chan<- struct{}
eventSent chan struct{} eventSent chan struct{}
sync.Mutex // protects everything below
numClients int
mailboxes map[string]*mailbox // indexed by label ID
flags map[string]string // indexed by label ID
} }
func getUser(be *backend, username string, c *protonmail.Client, privateKeys openpgp.EntityList) (*user, error) { func getUser(be *backend, username string, c *protonmail.Client, privateKeys openpgp.EntityList) (*user, error) {
if u, ok := be.users[username]; ok { if u, ok := be.users[username]; ok {
u.Lock()
u.numClients++ u.numClients++
u.Unlock()
return u, nil return u, nil
} else { } else {
pu, err := c.GetCurrentUser() pu, err := c.GetCurrentUser()
@ -141,8 +144,8 @@ func labelNameToFlag(s string) string {
} }
func (u *user) initMailboxes() error { func (u *user) initMailboxes() error {
u.locker.Lock() u.Lock()
defer u.locker.Unlock() defer u.Unlock()
u.mailboxes = make(map[string]*mailbox) u.mailboxes = make(map[string]*mailbox)
for _, data := range systemMailboxes { for _, data := range systemMailboxes {
@ -202,8 +205,8 @@ func (u *user) Username() string {
} }
func (u *user) ListMailboxes(subscribed bool) ([]imapbackend.Mailbox, error) { func (u *user) ListMailboxes(subscribed bool) ([]imapbackend.Mailbox, error) {
u.locker.Lock() u.Lock()
defer u.locker.Unlock() defer u.Unlock()
list := make([]imapbackend.Mailbox, 0, len(u.mailboxes)) list := make([]imapbackend.Mailbox, 0, len(u.mailboxes))
for _, mbox := range u.mailboxes { for _, mbox := range u.mailboxes {
@ -213,14 +216,14 @@ func (u *user) ListMailboxes(subscribed bool) ([]imapbackend.Mailbox, error) {
} }
func (u *user) getMailboxByLabel(labelID string) *mailbox { func (u *user) getMailboxByLabel(labelID string) *mailbox {
u.locker.Lock() u.Lock()
defer u.locker.Unlock() defer u.Unlock()
return u.mailboxes[labelID] return u.mailboxes[labelID]
} }
func (u *user) getMailbox(name string) *mailbox { func (u *user) getMailbox(name string) *mailbox {
u.locker.Lock() u.Lock()
defer u.locker.Unlock() defer u.Unlock()
for _, mbox := range u.mailboxes { for _, mbox := range u.mailboxes {
if mbox.name == name { if mbox.name == name {
@ -239,8 +242,8 @@ func (u *user) GetMailbox(name string) (imapbackend.Mailbox, error) {
} }
func (u *user) getFlag(name string) string { func (u *user) getFlag(name string) string {
u.locker.Lock() u.Lock()
defer u.locker.Unlock() defer u.Unlock()
for label, flag := range u.flags { for label, flag := range u.flags {
if flag == name { if flag == name {
@ -263,6 +266,9 @@ func (u *user) RenameMailbox(existingName, newName string) error {
} }
func (u *user) Logout() error { func (u *user) Logout() error {
u.Lock()
defer u.Unlock()
if u.numClients <= 0 { if u.numClients <= 0 {
panic("unreachable") panic("unreachable")
} }
@ -298,13 +304,13 @@ func (u *user) receiveEvents(updates chan<- imapbackend.Update, events <-chan *p
if event.Refresh&protonmail.EventRefreshMail != 0 { if event.Refresh&protonmail.EventRefreshMail != 0 {
log.Println("Reinitializing the whole IMAP database") log.Println("Reinitializing the whole IMAP database")
u.locker.Lock() u.Lock()
for _, mbox := range u.mailboxes { for _, mbox := range u.mailboxes {
if err := mbox.reset(); err != nil { if err := mbox.reset(); err != nil {
log.Printf("cannot reset mailbox %s: %v", mbox.name, err) log.Printf("cannot reset mailbox %s: %v", mbox.name, err)
} }
} }
u.locker.Unlock() u.Unlock()
if err := u.db.ResetMessages(); err != nil { if err := u.db.ResetMessages(); err != nil {
log.Printf("cannot reset user: %v", err) log.Printf("cannot reset user: %v", err)
@ -406,14 +412,14 @@ func (u *user) receiveEvents(updates chan<- imapbackend.Update, events <-chan *p
} }
} }
u.locker.Lock() u.Lock()
for _, count := range event.MessageCounts { for _, count := range event.MessageCounts {
if mbox, ok := u.mailboxes[count.LabelID]; ok { if mbox, ok := u.mailboxes[count.LabelID]; ok {
mbox.total = count.Total mbox.total = count.Total
mbox.unread = count.Unread mbox.unread = count.Unread
} }
} }
u.locker.Unlock() u.Unlock()
} }
for _, update := range eventUpdates { for _, update := range eventUpdates {