Skip to content

Commit

Permalink
fix(GODT-3117): Improve Contact Info Retrieval
Browse files Browse the repository at this point in the history
Rather than only fetching the total in on request and discarding all the
data, re-use the first page of data and then collect more of them if the
data set exceeds the page size.

This patch also includes various fixes to the GPA server to mimic
proton server behavior.
  • Loading branch information
LBeernaertProton committed Nov 20, 2023
1 parent 8a47c8d commit e90ebad
Show file tree
Hide file tree
Showing 7 changed files with 357 additions and 48 deletions.
84 changes: 72 additions & 12 deletions contact.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,15 @@ func (c *Client) CountContactEmails(ctx context.Context, email string) (int, err
}

func (c *Client) GetContacts(ctx context.Context, page, pageSize int) ([]Contact, error) {
_, contacts, err := c.getContactsImpl(ctx, page, pageSize)

return contacts, err
}

func (c *Client) getContactsImpl(ctx context.Context, page, pageSize int) (int, []Contact, error) {
var res struct {
Contacts []Contact
Total int
}

if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
Expand All @@ -60,26 +67,58 @@ func (c *Client) GetContacts(ctx context.Context, page, pageSize int) ([]Contact
"PageSize": strconv.Itoa(pageSize),
}).SetResult(&res).Get("/contacts/v4")
}); err != nil {
return nil, err
return 0, nil, err
}

return res.Contacts, nil
return res.Total, res.Contacts, nil
}

func (c *Client) GetAllContacts(ctx context.Context) ([]Contact, error) {
total, err := c.CountContacts(ctx)
return c.GetAllContactsPaged(ctx, maxPageSize)
}

func (c *Client) GetAllContactsPaged(ctx context.Context, pageSize int) ([]Contact, error) {
if pageSize > maxPageSize {
pageSize = maxPageSize
}

total, firstBatch, err := c.getContactsImpl(ctx, 0, pageSize)
if err != nil {
return nil, err
}

return fetchPaged(ctx, total, maxPageSize, c, func(ctx context.Context, page, pageSize int) ([]Contact, error) {
return c.GetContacts(ctx, page, pageSize)
})
if total <= pageSize {
return firstBatch, nil
}

remainingPages := (total / pageSize) + 1

for i := 1; i < remainingPages; i++ {
_, batch, err := c.getContactsImpl(ctx, i, pageSize)
if err != nil {
return nil, err
}

firstBatch = append(firstBatch, batch...)
}

return firstBatch, err
}

func (c *Client) GetContactEmails(ctx context.Context, email string, page, pageSize int) ([]ContactEmail, error) {
if pageSize > maxPageSize {
pageSize = maxPageSize
}

_, contacts, err := c.getContactEmailsImpl(ctx, email, page, pageSize)

return contacts, err
}

func (c *Client) getContactEmailsImpl(ctx context.Context, email string, page, pageSize int) (int, []ContactEmail, error) {
var res struct {
ContactEmails []ContactEmail
Total int
}

if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
Expand All @@ -89,21 +128,42 @@ func (c *Client) GetContactEmails(ctx context.Context, email string, page, pageS
"Email": email,
}).SetResult(&res).Get("/contacts/v4/emails")
}); err != nil {
return nil, err
return 0, nil, err
}

return res.ContactEmails, nil
return res.Total, res.ContactEmails, nil
}

func (c *Client) GetAllContactEmails(ctx context.Context, email string) ([]ContactEmail, error) {
total, err := c.CountContactEmails(ctx, email)
return c.GetAllContactEmailsPaged(ctx, email, maxPageSize)
}

func (c *Client) GetAllContactEmailsPaged(ctx context.Context, email string, pageSize int) ([]ContactEmail, error) {
if pageSize > maxPageSize {
pageSize = maxPageSize
}

total, firstBatch, err := c.getContactEmailsImpl(ctx, email, 0, pageSize)
if err != nil {
return nil, err
}

return fetchPaged(ctx, total, maxPageSize, c, func(ctx context.Context, page, pageSize int) ([]ContactEmail, error) {
return c.GetContactEmails(ctx, email, page, pageSize)
})
if total <= pageSize {
return firstBatch, nil
}

remainingPages := (total / pageSize) + 1

for i := 1; i < remainingPages; i++ {
_, batch, err := c.getContactEmailsImpl(ctx, email, i, pageSize)
if err != nil {
return nil, err
}

firstBatch = append(firstBatch, batch...)
}

return firstBatch, err
}

func (c *Client) CreateContacts(ctx context.Context, req CreateContactsReq) ([]CreateContactsRes, error) {
Expand Down
11 changes: 11 additions & 0 deletions contact_card.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,17 @@ func (c *Card) Set(kr *crypto.KeyRing, key string, value *vcard.Field) error {
return c.encode(kr, dec)
}

func (c *Card) Add(kr *crypto.KeyRing, key string, value *vcard.Field) error {
dec, err := c.decode(kr)
if err != nil {
return err
}

dec.Add(key, value)

return c.encode(kr, dec)
}

func (c *Card) ChangeType(kr *crypto.KeyRing, cardType CardType) error {
dec, err := c.decode(kr)
if err != nil {
Expand Down
13 changes: 7 additions & 6 deletions server/backend/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@ import (
)

type account struct {
userID string
username string
addresses map[string]*address
mailSettings *mailSettings
userSettings proton.UserSettings
contacts map[string]*proton.Contact
userID string
username string
addresses map[string]*address
mailSettings *mailSettings
userSettings proton.UserSettings
contacts map[string]*proton.Contact
contactCounter int

auth map[string]auth
authLock sync.RWMutex
Expand Down
44 changes: 25 additions & 19 deletions server/backend/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1106,18 +1106,22 @@ func (b *Backend) GetUserContact(userID, contactID string) (proton.Contact, erro
})
}

func (b *Backend) GetUserContacts(userID string) ([]proton.Contact, error) {
return withAcc(b, userID, func(acc *account) ([]proton.Contact, error) {
var contacts []proton.Contact
for _, contact := range acc.contacts {
contacts = append(contacts, *contact)
}
return contacts, nil
func (b *Backend) GetUserContacts(userID string, page int, pageSize int) (int, []proton.Contact, error) {
var total int
contacts, err := withAcc(b, userID, func(acc *account) ([]proton.Contact, error) {
total = len(acc.contacts)
return xslices.Map(xslices.Chunk(maps.Values(acc.contacts), pageSize)[page], func(c *proton.Contact) proton.Contact {
return *c
}), nil
})

return total, contacts, err
}

func (b *Backend) GetUserContactEmails(userID, email string) ([]proton.ContactEmail, error) {
return withAcc(b, userID, func(acc *account) ([]proton.ContactEmail, error) {
func (b *Backend) GetUserContactEmails(userID, email string, page int, pageSize int) (int, []proton.ContactEmail, error) {
var total int

emails, err := withAcc(b, userID, func(acc *account) ([]proton.ContactEmail, error) {
var contacts []proton.ContactEmail
for _, contact := range acc.contacts {
for _, contactEmail := range contact.ContactEmails {
Expand All @@ -1126,8 +1130,17 @@ func (b *Backend) GetUserContactEmails(userID, email string) ([]proton.ContactEm
}
}
}
return contacts, nil

total = len(contacts)

if total < pageSize {
return contacts, nil
}

return xslices.Chunk(contacts, pageSize)[page], nil
})

return total, emails, err
}

func (b *Backend) AddUserContact(userID string, contact proton.Contact) (proton.Contact, error) {
Expand All @@ -1146,15 +1159,8 @@ func (b *Backend) UpdateUserContact(userID, contactID string, cards proton.Cards

func (b *Backend) GenerateContactID(userID string) (string, error) {
return withAcc(b, userID, func(acc *account) (string, error) {
var lastKey = "0"
for k := range acc.contacts {
lastKey = k
}
newKey, err := strconv.Atoi(lastKey)
if err != nil {
return "", err
}
return strconv.Itoa(newKey + 1), nil
acc.contactCounter++
return strconv.Itoa(acc.contactCounter), nil
})
}

Expand Down
21 changes: 14 additions & 7 deletions server/backend/contact.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@ package backend
import (
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/bradenaw/juniper/xslices"
"github.com/emersion/go-vcard"
"strconv"
"sync/atomic"
)

var globalContactID int32

func ContactCardToContact(card *proton.Card, contactID string, kr *crypto.KeyRing) (proton.Contact, error) {
emails, err := card.Get(kr, vcard.FieldEmail)
if err != nil {
Expand All @@ -19,13 +24,15 @@ func ContactCardToContact(card *proton.Card, contactID string, kr *crypto.KeyRin
ContactMetadata: proton.ContactMetadata{
ID: contactID,
Name: names[0].Value,
ContactEmails: []proton.ContactEmail{proton.ContactEmail{
ID: "1",
Name: names[0].Value,
Email: emails[0].Value,
ContactID: contactID,
},
},
ContactEmails: xslices.Map(emails, func(email *vcard.Field) proton.ContactEmail {
id := atomic.AddInt32(&globalContactID, 1)
return proton.ContactEmail{
ID: strconv.Itoa(int(id)),
Name: names[0].Value,
Email: email.Value,
ContactID: contactID,
}
}),
},
ContactCards: proton.ContactCards{Cards: proton.Cards{card}},
}, nil
Expand Down
17 changes: 13 additions & 4 deletions server/contacts.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package server

import (
"net/http"
"strconv"

"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/go-proton-api/server/backend"
Expand All @@ -10,29 +11,37 @@ import (

func (s *Server) handleGetContacts() gin.HandlerFunc {
return func(c *gin.Context) {
contacts, err := s.b.GetUserContacts(c.GetString("UserID"))
total, contacts, err := s.b.GetUserContacts(c.GetString("UserID"),
mustParseInt(c.DefaultQuery("Page", strconv.Itoa(defaultPage))),
mustParseInt(c.DefaultQuery("PageSize", strconv.Itoa(defaultPageSize))),
)
if err != nil {
c.AbortWithStatus(http.StatusBadRequest)
return
}

c.JSON(http.StatusOK, gin.H{
"Code": proton.MultiCode,
"ContactEmails": contacts,
"Code": proton.MultiCode,
"Contacts": contacts,
"Total": total,
})
}
}

func (s *Server) handleGetContactsEmails() gin.HandlerFunc {
return func(c *gin.Context) {
contacts, err := s.b.GetUserContactEmails(c.GetString("UserID"), c.GetString("email"))
total, contacts, err := s.b.GetUserContactEmails(c.GetString("UserID"), c.Query("Email"),
mustParseInt(c.DefaultQuery("Page", strconv.Itoa(defaultPage))),
mustParseInt(c.DefaultQuery("PageSize", strconv.Itoa(defaultPageSize))),
)
if err != nil {
c.AbortWithStatus(http.StatusBadRequest)
return
}
c.JSON(http.StatusOK, gin.H{
"Code": proton.MultiCode,
"ContactEmails": contacts,
"Total": total,
})
}
}
Expand Down
Loading

0 comments on commit e90ebad

Please sign in to comment.