Skip to content

Commit

Permalink
feat(GODT-2567): Simulate Answered/Forwarded behavior in GPA server
Browse files Browse the repository at this point in the history
  • Loading branch information
LBeernaertProton committed Nov 16, 2023
1 parent c9bc6f7 commit 2624cd3
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 18 deletions.
24 changes: 22 additions & 2 deletions server/backend/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -636,19 +636,21 @@ func (b *Backend) DeleteMessage(userID, messageID string) error {
})
}

func (b *Backend) CreateDraft(userID, addrID string, draft proton.DraftTemplate, parentID string) (proton.Message, error) {
func (b *Backend) CreateDraft(userID, addrID string, draft proton.DraftTemplate, parentID string, action proton.CreateDraftAction) (proton.Message, error) {
return withAcc(b, userID, func(acc *account) (proton.Message, error) {
return withMessages(b, func(messages map[string]*message) (proton.Message, error) {
return withLabels(b, func(labels map[string]*label) (proton.Message, error) {
// Convert the parentID into externalRef.\
var parentRef string
var internalParentID string
if parentID != "" {
parentMsg, ok := messages[parentID]
if ok {
parentRef = "<" + parentMsg.externalID + ">"
internalParentID = parentID
}
}
msg := newMessageFromTemplate(addrID, draft, parentRef)
msg := newMessageFromTemplate(addrID, draft, parentRef, internalParentID, action)
// Drafts automatically get the sysLabel "Drafts".
msg.addLabel(proton.DraftsLabel, labels)

Expand Down Expand Up @@ -712,6 +714,24 @@ func (b *Backend) SendMessage(userID, messageID string, packages []*proton.Messa
msg.flags |= proton.MessageFlagSent
msg.addLabel(proton.SentLabel, labels)

if parent, ok := messages[msg.internalParentID]; ok {
switch msg.draftAction {
case proton.ReplyAction:
parent.flags |= proton.MessageFlagReplied
case proton.ReplyAllAction:
parent.flags |= proton.MessageFlagRepliedAll
case proton.ForwardAction:
parent.flags |= proton.MessageFlagForwarded
}

updateID, err := b.newUpdate(&messageUpdated{messageID: msg.internalParentID})
if err != nil {
return proton.Message{}, err
}

acc.updateIDs = append(acc.updateIDs, updateID)
}

updateID, err := b.newUpdate(&messageUpdated{messageID: messageID})
if err != nil {
return proton.Message{}, err
Expand Down
36 changes: 24 additions & 12 deletions server/backend/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@ import (
)

type message struct {
messageID string
externalID string
addrID string
labelIDs []string
attIDs []string
inReplyTo string
messageID string
externalID string
addrID string
labelIDs []string
attIDs []string
inReplyTo string
internalParentID string

// sysLabel is the system label for the message.
// If nil, the message's flags are used to determine the system label (inbox, sent, drafts).
Expand All @@ -34,6 +35,8 @@ type message struct {
replytos []*mail.Address
date time.Time

draftAction proton.CreateDraftAction

armBody string
mimeType rfc822.MIMEType

Expand Down Expand Up @@ -92,13 +95,20 @@ func newMessageFromSent(addrID, armBody string, msg *message) *message {
}
}

func newMessageFromTemplate(addrID string, template proton.DraftTemplate, parentRef string) *message {
func newMessageFromTemplate(
addrID string,
template proton.DraftTemplate,
parentRef string,
internalParentID string,
action proton.CreateDraftAction,
) *message {
return &message{
messageID: uuid.NewString(),
externalID: template.ExternalID,
addrID: addrID,
sysLabel: pointer(""),
inReplyTo: parentRef,
messageID: uuid.NewString(),
externalID: template.ExternalID,
addrID: addrID,
sysLabel: pointer(""),
inReplyTo: parentRef,
internalParentID: internalParentID,

subject: template.Subject,
sender: template.Sender,
Expand All @@ -107,6 +117,8 @@ func newMessageFromTemplate(addrID string, template proton.DraftTemplate, parent
bccList: template.BCCList,
unread: bool(template.Unread),

draftAction: action,

armBody: template.Body,
mimeType: template.MIMEType,
}
Expand Down
2 changes: 1 addition & 1 deletion server/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func (s *Server) postMailMessages(c *gin.Context) {
return
}

message, err := s.b.CreateDraft(c.GetString("UserID"), addrID, req.Message, req.ParentID)
message, err := s.b.CreateDraft(c.GetString("UserID"), addrID, req.Message, req.ParentID, req.Action)
if err != nil {
c.AbortWithStatus(http.StatusUnprocessableEntity)
return
Expand Down
88 changes: 85 additions & 3 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/ProtonMail/go-proton-api/server/backend"
"net/http"
"net/mail"
"net/url"
Expand All @@ -17,14 +16,14 @@ import (
"testing"
"time"

"github.com/bradenaw/juniper/parallel"

"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/go-proton-api/server/backend"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/bradenaw/juniper/iterator"
"github.com/bradenaw/juniper/parallel"
"github.com/bradenaw/juniper/stream"
"github.com/bradenaw/juniper/xslices"
"github.com/google/uuid"
Expand Down Expand Up @@ -2232,6 +2231,89 @@ func TestServer_GetMessageGroupCount(t *testing.T) {
})
}

func TestServer_TestDraftActions(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

user, err := c.GetUser(ctx)
require.NoError(t, err)

addr, err := c.GetAddresses(ctx)
require.NoError(t, err)

salt, err := c.GetSalts(ctx)
require.NoError(t, err)

pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID)
require.NoError(t, err)

_, addrKRs, err := proton.Unlock(user, addr, pass, async.NoopPanicHandler{})
require.NoError(t, err)

type testData struct {
action proton.CreateDraftAction
flag proton.MessageFlag
}

tests := []testData{
{
action: proton.ReplyAction,
flag: proton.MessageFlagReplied,
},
{
action: proton.ReplyAllAction,
flag: proton.MessageFlagRepliedAll,
},
{
action: proton.ForwardAction,
flag: proton.MessageFlagForwarded,
},
}

importedMessages := importMessages(ctx, t, c, addr[0].ID, addrKRs[addr[0].ID], []string{}, 0, len(tests))

for i := 0; i < len(tests); i++ {
importedMessageID := importedMessages[i].MessageID

msg, err := c.GetMessage(ctx, importedMessageID)
require.NoError(t, err)

{
kr := addrKRs[addr[0].ID]
msg, err := c.CreateDraft(ctx, kr, proton.CreateDraftReq{
Message: proton.DraftTemplate{
Subject: "Foo",
Sender: &mail.Address{Address: addr[0].Email},
ToList: []*mail.Address{{Address: "foo@bar"}},
CCList: nil,
BCCList: nil,
},
AttachmentKeyPackets: nil,
ParentID: msg.ID,
Action: tests[i].action,
})

require.NoError(t, err)

var sreq proton.SendDraftReq

require.NoError(t, sreq.AddTextPackage(kr, "Hello", "text/plain", map[string]proton.SendPreferences{}, map[string]*crypto.SessionKey{}))

_, err = c.SendDraft(ctx, msg.ID, sreq)
require.NoError(t, err)

msg, err = c.GetMessage(ctx, importedMessageID)
require.NoError(t, err)
require.True(t, msg.Flags&tests[i].flag != 0)
}
}

})
})
}

func withServer(t *testing.T, fn func(ctx context.Context, s *Server, m *proton.Manager), opts ...Option) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down

0 comments on commit 2624cd3

Please sign in to comment.