diff --git a/contact.go b/contact.go index 8c8abcf..8f4daf6 100644 --- a/contact.go +++ b/contact.go @@ -50,8 +50,15 @@ func (c *Client) CountContactEmails(ctx context.Context, email string) (int, err } func (c *Client) GetContacts(ctx context.Context, page, pageSize int) ([]Contact, error) { + _, contacts, err := c.getContactsImpl(ctx, page, pageSize) + + return contacts, err +} + +func (c *Client) getContactsImpl(ctx context.Context, page, pageSize int) (int, []Contact, error) { var res struct { Contacts []Contact + Total int } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { @@ -60,26 +67,58 @@ func (c *Client) GetContacts(ctx context.Context, page, pageSize int) ([]Contact "PageSize": strconv.Itoa(pageSize), }).SetResult(&res).Get("/contacts/v4") }); err != nil { - return nil, err + return 0, nil, err } - return res.Contacts, nil + return res.Total, res.Contacts, nil } func (c *Client) GetAllContacts(ctx context.Context) ([]Contact, error) { - total, err := c.CountContacts(ctx) + return c.GetAllContactsPaged(ctx, maxPageSize) +} + +func (c *Client) GetAllContactsPaged(ctx context.Context, pageSize int) ([]Contact, error) { + if pageSize > maxPageSize { + pageSize = maxPageSize + } + + total, firstBatch, err := c.getContactsImpl(ctx, 0, pageSize) if err != nil { return nil, err } - return fetchPaged(ctx, total, maxPageSize, c, func(ctx context.Context, page, pageSize int) ([]Contact, error) { - return c.GetContacts(ctx, page, pageSize) - }) + if total <= pageSize { + return firstBatch, nil + } + + remainingPages := (total / pageSize) + 1 + + for i := 1; i < remainingPages; i++ { + _, batch, err := c.getContactsImpl(ctx, i, pageSize) + if err != nil { + return nil, err + } + + firstBatch = append(firstBatch, batch...) + } + + return firstBatch, err } func (c *Client) GetContactEmails(ctx context.Context, email string, page, pageSize int) ([]ContactEmail, error) { + if pageSize > maxPageSize { + pageSize = maxPageSize + } + + _, contacts, err := c.getContactEmailsImpl(ctx, email, page, pageSize) + + return contacts, err +} + +func (c *Client) getContactEmailsImpl(ctx context.Context, email string, page, pageSize int) (int, []ContactEmail, error) { var res struct { ContactEmails []ContactEmail + Total int } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { @@ -89,21 +128,42 @@ func (c *Client) GetContactEmails(ctx context.Context, email string, page, pageS "Email": email, }).SetResult(&res).Get("/contacts/v4/emails") }); err != nil { - return nil, err + return 0, nil, err } - return res.ContactEmails, nil + return res.Total, res.ContactEmails, nil } func (c *Client) GetAllContactEmails(ctx context.Context, email string) ([]ContactEmail, error) { - total, err := c.CountContactEmails(ctx, email) + return c.GetAllContactEmailsPaged(ctx, email, maxPageSize) +} + +func (c *Client) GetAllContactEmailsPaged(ctx context.Context, email string, pageSize int) ([]ContactEmail, error) { + if pageSize > maxPageSize { + pageSize = maxPageSize + } + + total, firstBatch, err := c.getContactEmailsImpl(ctx, email, 0, pageSize) if err != nil { return nil, err } - return fetchPaged(ctx, total, maxPageSize, c, func(ctx context.Context, page, pageSize int) ([]ContactEmail, error) { - return c.GetContactEmails(ctx, email, page, pageSize) - }) + if total <= pageSize { + return firstBatch, nil + } + + remainingPages := (total / pageSize) + 1 + + for i := 1; i < remainingPages; i++ { + _, batch, err := c.getContactEmailsImpl(ctx, email, i, pageSize) + if err != nil { + return nil, err + } + + firstBatch = append(firstBatch, batch...) + } + + return firstBatch, err } func (c *Client) CreateContacts(ctx context.Context, req CreateContactsReq) ([]CreateContactsRes, error) { diff --git a/contact_card.go b/contact_card.go index 16ae99d..769f4c1 100644 --- a/contact_card.go +++ b/contact_card.go @@ -109,6 +109,17 @@ func (c *Card) Set(kr *crypto.KeyRing, key string, value *vcard.Field) error { return c.encode(kr, dec) } +func (c *Card) Add(kr *crypto.KeyRing, key string, value *vcard.Field) error { + dec, err := c.decode(kr) + if err != nil { + return err + } + + dec.Add(key, value) + + return c.encode(kr, dec) +} + func (c *Card) ChangeType(kr *crypto.KeyRing, cardType CardType) error { dec, err := c.decode(kr) if err != nil { diff --git a/server/backend/account.go b/server/backend/account.go index e46f440..3c501b2 100644 --- a/server/backend/account.go +++ b/server/backend/account.go @@ -10,12 +10,13 @@ import ( ) type account struct { - userID string - username string - addresses map[string]*address - mailSettings *mailSettings - userSettings proton.UserSettings - contacts map[string]*proton.Contact + userID string + username string + addresses map[string]*address + mailSettings *mailSettings + userSettings proton.UserSettings + contacts map[string]*proton.Contact + contactCounter int auth map[string]auth authLock sync.RWMutex diff --git a/server/backend/api.go b/server/backend/api.go index d064cae..0d8a1b3 100644 --- a/server/backend/api.go +++ b/server/backend/api.go @@ -1106,18 +1106,26 @@ func (b *Backend) GetUserContact(userID, contactID string) (proton.Contact, erro }) } -func (b *Backend) GetUserContacts(userID string) ([]proton.Contact, error) { - return withAcc(b, userID, func(acc *account) ([]proton.Contact, error) { - var contacts []proton.Contact - for _, contact := range acc.contacts { - contacts = append(contacts, *contact) - } - return contacts, nil +func (b *Backend) GetUserContacts(userID string, page int, pageSize int) (int, []proton.Contact, error) { + var total int + contacts, err := withAcc(b, userID, func(acc *account) ([]proton.Contact, error) { + total = len(acc.contacts) + values := maps.Values(acc.contacts) + slices.SortFunc(values, func(i, j *proton.Contact) bool { + return strings.Compare(i.ID, j.ID) < 0 + }) + return xslices.Map(xslices.Chunk(values, pageSize)[page], func(c *proton.Contact) proton.Contact { + return *c + }), nil }) + + return total, contacts, err } -func (b *Backend) GetUserContactEmails(userID, email string) ([]proton.ContactEmail, error) { - return withAcc(b, userID, func(acc *account) ([]proton.ContactEmail, error) { +func (b *Backend) GetUserContactEmails(userID, email string, page int, pageSize int) (int, []proton.ContactEmail, error) { + var total int + + emails, err := withAcc(b, userID, func(acc *account) ([]proton.ContactEmail, error) { var contacts []proton.ContactEmail for _, contact := range acc.contacts { for _, contactEmail := range contact.ContactEmails { @@ -1126,8 +1134,21 @@ func (b *Backend) GetUserContactEmails(userID, email string) ([]proton.ContactEm } } } - return contacts, nil + + total = len(contacts) + + if total < pageSize { + return contacts, nil + } + + slices.SortFunc(contacts, func(a, b proton.ContactEmail) bool { + return strings.Compare(a.ID, b.ID) < 0 + }) + + return xslices.Chunk(contacts, pageSize)[page], nil }) + + return total, emails, err } func (b *Backend) AddUserContact(userID string, contact proton.Contact) (proton.Contact, error) { @@ -1146,15 +1167,8 @@ func (b *Backend) UpdateUserContact(userID, contactID string, cards proton.Cards func (b *Backend) GenerateContactID(userID string) (string, error) { return withAcc(b, userID, func(acc *account) (string, error) { - var lastKey = "0" - for k := range acc.contacts { - lastKey = k - } - newKey, err := strconv.Atoi(lastKey) - if err != nil { - return "", err - } - return strconv.Itoa(newKey + 1), nil + acc.contactCounter++ + return strconv.Itoa(acc.contactCounter), nil }) } diff --git a/server/backend/contact.go b/server/backend/contact.go index 158d7e7..06a94e9 100644 --- a/server/backend/contact.go +++ b/server/backend/contact.go @@ -3,9 +3,14 @@ package backend import ( "github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/bradenaw/juniper/xslices" "github.com/emersion/go-vcard" + "strconv" + "sync/atomic" ) +var globalContactID int32 + func ContactCardToContact(card *proton.Card, contactID string, kr *crypto.KeyRing) (proton.Contact, error) { emails, err := card.Get(kr, vcard.FieldEmail) if err != nil { @@ -19,13 +24,15 @@ func ContactCardToContact(card *proton.Card, contactID string, kr *crypto.KeyRin ContactMetadata: proton.ContactMetadata{ ID: contactID, Name: names[0].Value, - ContactEmails: []proton.ContactEmail{proton.ContactEmail{ - ID: "1", - Name: names[0].Value, - Email: emails[0].Value, - ContactID: contactID, - }, - }, + ContactEmails: xslices.Map(emails, func(email *vcard.Field) proton.ContactEmail { + id := atomic.AddInt32(&globalContactID, 1) + return proton.ContactEmail{ + ID: strconv.Itoa(int(id)), + Name: names[0].Value, + Email: email.Value, + ContactID: contactID, + } + }), }, ContactCards: proton.ContactCards{Cards: proton.Cards{card}}, }, nil diff --git a/server/contacts.go b/server/contacts.go index 17dae7f..547515b 100644 --- a/server/contacts.go +++ b/server/contacts.go @@ -2,6 +2,7 @@ package server import ( "net/http" + "strconv" "github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/go-proton-api/server/backend" @@ -10,22 +11,29 @@ import ( func (s *Server) handleGetContacts() gin.HandlerFunc { return func(c *gin.Context) { - contacts, err := s.b.GetUserContacts(c.GetString("UserID")) + total, contacts, err := s.b.GetUserContacts(c.GetString("UserID"), + mustParseInt(c.DefaultQuery("Page", strconv.Itoa(defaultPage))), + mustParseInt(c.DefaultQuery("PageSize", strconv.Itoa(defaultPageSize))), + ) if err != nil { c.AbortWithStatus(http.StatusBadRequest) return } c.JSON(http.StatusOK, gin.H{ - "Code": proton.MultiCode, - "ContactEmails": contacts, + "Code": proton.MultiCode, + "Contacts": contacts, + "Total": total, }) } } func (s *Server) handleGetContactsEmails() gin.HandlerFunc { return func(c *gin.Context) { - contacts, err := s.b.GetUserContactEmails(c.GetString("UserID"), c.GetString("email")) + total, contacts, err := s.b.GetUserContactEmails(c.GetString("UserID"), c.Query("Email"), + mustParseInt(c.DefaultQuery("Page", strconv.Itoa(defaultPage))), + mustParseInt(c.DefaultQuery("PageSize", strconv.Itoa(defaultPageSize))), + ) if err != nil { c.AbortWithStatus(http.StatusBadRequest) return @@ -33,6 +41,7 @@ func (s *Server) handleGetContactsEmails() gin.HandlerFunc { c.JSON(http.StatusOK, gin.H{ "Code": proton.MultiCode, "ContactEmails": contacts, + "Total": total, }) } } diff --git a/server/server_test.go b/server/server_test.go index 50c2cd0..a3c0bef 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -26,7 +26,9 @@ import ( "github.com/bradenaw/juniper/parallel" "github.com/bradenaw/juniper/stream" "github.com/bradenaw/juniper/xslices" + "github.com/emersion/go-vcard" "github.com/google/uuid" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/exp/slices" ) @@ -2314,6 +2316,219 @@ func TestServer_TestDraftActions(t *testing.T) { }) } +func TestServer_Contacts(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { + + user, err := c.GetUser(ctx) + require.NoError(t, err) + + addr, err := c.GetAddresses(ctx) + require.NoError(t, err) + + salt, err := c.GetSalts(ctx) + require.NoError(t, err) + + pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID) + require.NoError(t, err) + + _, addrKRs, err := proton.Unlock(user, addr, pass, async.NoopPanicHandler{}) + require.NoError(t, err) + + type testContact struct { + Name string + Email string + } + + testContacts := []testContact{ + { + Name: "foo", + Email: "foo@bar.com", + }, + { + Name: "bar", + Email: "bar@bar.com", + }, + { + Name: "zz", + Email: "zz@bar.com", + }, + } + + contactDesc := []proton.ContactCards{ + { + Cards: xslices.Map(testContacts, func(contact testContact) *proton.Card { + return createVCard(t, addrKRs[addr[0].ID], contact.Name, contact.Email) + }), + }, + } + createReq := proton.CreateContactsReq{ + Contacts: contactDesc, + Overwrite: 0, + Labels: 0, + } + + contactsRes, err := c.CreateContacts(ctx, createReq) + require.NoError(t, err) + assert.Equal(t, 3, len(contactsRes)) + + contacts, err := c.GetAllContactsPaged(ctx, 2) + require.NoError(t, err) + require.Len(t, contacts, len(testContacts)) + + for _, v := range testContacts { + require.NotEqual(t, -1, xslices.IndexFunc(contacts, func(contact proton.Contact) bool { + return contact.Name == v.Name + })) + } + }) + }) +} + +func TestServer_ContactEmails(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { + + user, err := c.GetUser(ctx) + require.NoError(t, err) + + addr, err := c.GetAddresses(ctx) + require.NoError(t, err) + + salt, err := c.GetSalts(ctx) + require.NoError(t, err) + + pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID) + require.NoError(t, err) + + _, addrKRs, err := proton.Unlock(user, addr, pass, async.NoopPanicHandler{}) + require.NoError(t, err) + + type testContact struct { + Name string + Emails []string + } + + testContacts := []testContact{ + { + Name: "foo", + Emails: []string{"foo@bar.com", "alias@alias.com", "nn@zz.com", "abc@4.de", "001234@00.com"}, + }, + { + Name: "bar", + Emails: []string{"bar@bar.com"}, + }, + { + Name: "zz", + Emails: []string{"zz@bar.com", "zz@zz2.com"}, + }, + } + + contactDesc := []proton.ContactCards{ + { + Cards: xslices.Map(testContacts, func(contact testContact) *proton.Card { + return createVCard(t, addrKRs[addr[0].ID], contact.Name, contact.Emails...) + }), + }, + } + createReq := proton.CreateContactsReq{ + Contacts: contactDesc, + Overwrite: 0, + Labels: 0, + } + + contactsRes, err := c.CreateContacts(ctx, createReq) + require.NoError(t, err) + assert.Equal(t, 3, len(contactsRes)) + + for _, v := range testContacts { + for _, email := range v.Emails { + emails, err := c.GetAllContactEmailsPaged(ctx, email, 2) + require.NoError(t, err) + require.Len(t, emails, 1) + assert.Equal(t, email, emails[0].Email) + } + } + }) + }) +} + +func TestServer_ContactEmailsRepeated(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { + + user, err := c.GetUser(ctx) + require.NoError(t, err) + + addr, err := c.GetAddresses(ctx) + require.NoError(t, err) + + salt, err := c.GetSalts(ctx) + require.NoError(t, err) + + pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID) + require.NoError(t, err) + + _, addrKRs, err := proton.Unlock(user, addr, pass, async.NoopPanicHandler{}) + require.NoError(t, err) + + type testContact struct { + Name string + Emails []string + } + + testContacts := []testContact{ + { + Name: "foo", + Emails: []string{"foo@bar.com"}, + }, + { + Name: "bar", + Emails: []string{"foo@bar.com"}, + }, + { + Name: "zz", + Emails: []string{"foo@bar.com"}, + }, + } + + contactDesc := []proton.ContactCards{ + { + Cards: xslices.Map(testContacts, func(contact testContact) *proton.Card { + return createVCard(t, addrKRs[addr[0].ID], contact.Name, contact.Emails...) + }), + }, + } + createReq := proton.CreateContactsReq{ + Contacts: contactDesc, + Overwrite: 0, + Labels: 0, + } + + contactsRes, err := c.CreateContacts(ctx, createReq) + require.NoError(t, err) + assert.Equal(t, 3, len(contactsRes)) + + emails, err := c.GetAllContactEmailsPaged(ctx, "foo@bar.com", 2) + require.NoError(t, err) + require.Len(t, emails, len(testContacts)) + }) + }) +} + +func createVCard(t *testing.T, addrKR *crypto.KeyRing, name string, email ...string) *proton.Card { + card, err := proton.NewCard(addrKR, proton.CardTypeSigned) + require.NoError(t, err) + + require.NoError(t, card.Set(addrKR, vcard.FieldUID, &vcard.Field{Value: fmt.Sprintf("proton-legacy-%v", uuid.NewString()), Group: "test"})) + require.NoError(t, card.Set(addrKR, vcard.FieldFormattedName, &vcard.Field{Value: name, Group: "test"})) + for _, email := range email { + require.NoError(t, card.Add(addrKR, vcard.FieldEmail, &vcard.Field{Value: email, Group: "test"})) + } + + return card +} + func withServer(t *testing.T, fn func(ctx context.Context, s *Server, m *proton.Manager), opts ...Option) { ctx, cancel := context.WithCancel(context.Background()) defer cancel()