Skip to content

Commit

Permalink
Implement a unified query response subject
Browse files Browse the repository at this point in the history
Implement the changes explained in overmindtech/sdp#80

ref overmindtech/sdp#79
  • Loading branch information
DavidS-ovm committed May 29, 2023
1 parent 300355f commit 44c5c6a
Show file tree
Hide file tree
Showing 7 changed files with 214 additions and 154 deletions.
39 changes: 37 additions & 2 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ package sdp
import (
"context"
"fmt"
reflect "reflect"

"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
)

Expand Down Expand Up @@ -39,7 +42,23 @@ type EncodedConnectionImpl struct {
// assert interface implementation
var _ EncodedConnection = (*EncodedConnectionImpl)(nil)

// recordMessage notes a single messaging event in two places: as a
// trace-level log entry and as an event on the span carried by ctx.
//
// name is used both as the log message and as the span event name. The log
// entry carries the message type (typ), subject (subj) and payload text (msg);
// the span event carries only the subject and the payload text.
func recordMessage(ctx context.Context, name, subj, typ, msg string) {
	fields := log.Fields{
		"msg type": typ,
		"subj":     subj,
		"msg":      msg,
	}
	log.WithContext(ctx).WithFields(fields).Trace(name)

	attrs := trace.WithAttributes(
		attribute.String("om.sdp.subject", subj),
		attribute.String("om.sdp.message", msg),
	)
	trace.SpanFromContext(ctx).AddEvent(name, attrs)
}

func (ec *EncodedConnectionImpl) Publish(ctx context.Context, subj string, m proto.Message) error {
// TODO: protojson.Format is pretty expensive, replace with summarized data
recordMessage(ctx, "Publish", subj, fmt.Sprint(reflect.TypeOf(m)), protojson.Format(m))

data, err := proto.Marshal(m)
if err != nil {
return err
Expand All @@ -49,10 +68,13 @@ func (ec *EncodedConnectionImpl) Publish(ctx context.Context, subj string, m pro
Subject: subj,
Data: data,
}
return ec.PublishMsg(ctx, msg)
InjectOtelTraceContext(ctx, msg)
return ec.Conn.PublishMsg(msg)
}

// PublishMsg publishes an already-encoded NATS message on the underlying
// connection. The publish is recorded via recordMessage with a placeholder
// payload ("binary"), since the data is not decoded here. Before sending, msg
// is passed to InjectOtelTraceContext (presumably to attach the trace context
// from ctx to the message — confirm against its definition).
//
// It returns whatever error the underlying connection's PublishMsg returns.
func (ec *EncodedConnectionImpl) PublishMsg(ctx context.Context, msg *nats.Msg) error {
	subject := msg.Subject
	recordMessage(ctx, "Publish", subject, "[]byte", "binary")

	InjectOtelTraceContext(ctx, msg)

	err := ec.Conn.PublishMsg(msg)
	return err
}
Expand All @@ -68,8 +90,16 @@ func (ec *EncodedConnectionImpl) QueueSubscribe(subj, queue string, cb nats.MsgH
}

// RequestMsg sends msg as a request on the underlying connection and waits
// for the reply, honouring ctx via RequestMsgWithContext.
//
// The outgoing request is recorded via recordMessage and msg is passed to
// InjectOtelTraceContext before sending. Once the call returns, either the
// error (with its dynamic type and text) or the arrival of a reply is
// recorded as well. Payloads are recorded as the placeholder "binary" because
// they are not decoded here.
//
// It returns the reply and error exactly as produced by the underlying
// connection.
func (ec *EncodedConnectionImpl) RequestMsg(ctx context.Context, msg *nats.Msg) (*nats.Msg, error) {
	recordMessage(ctx, "RequestMsg", msg.Subject, "[]byte", "binary")
	InjectOtelTraceContext(ctx, msg)

	// Capture the result rather than returning it directly so that the
	// outcome can be recorded before handing it back to the caller.
	reply, err := ec.Conn.RequestMsgWithContext(ctx, msg)
	if err != nil {
		recordMessage(ctx, "RequestMsg Error", msg.Subject, fmt.Sprintf("%T", err), err.Error())
		return reply, err
	}

	recordMessage(ctx, "RequestMsg Reply", msg.Subject, "[]byte", "binary")
	return reply, nil
}

func (ec *EncodedConnectionImpl) Drain() error {
Expand Down Expand Up @@ -104,6 +134,7 @@ func (ec *EncodedConnectionImpl) Drop() {
func Unmarshal(ctx context.Context, b []byte, m proto.Message) error {
err := proto.Unmarshal(b, m)
if err != nil {
recordMessage(ctx, "Unmarshal err", "unknown", fmt.Sprint(reflect.TypeOf(err)), err.Error())
log.WithContext(ctx).Errorf("Error parsing message: %v", err)
trace.SpanFromContext(ctx).SetStatus(codes.Error, fmt.Sprintf("Error parsing message: %v", err))
return err
Expand All @@ -114,16 +145,20 @@ func Unmarshal(ctx context.Context, b []byte, m proto.Message) error {
// some remaining unknown fields. If there are some, fail.
if unk := m.ProtoReflect().GetUnknown(); unk != nil {
err = fmt.Errorf("unmarshal to %T had unknown fields, likely a type mismatch. Unknowns: %v", m, unk)
recordMessage(ctx, "Unmarshal unknown", "unknown", fmt.Sprint(reflect.TypeOf(m)), protojson.Format(m))
log.WithContext(ctx).Errorf("Error parsing message: %v", err)
trace.SpanFromContext(ctx).SetStatus(codes.Error, fmt.Sprintf("Error parsing message: %v", err))
return err
}

recordMessage(ctx, "Unmarshal", "unknown", fmt.Sprint(reflect.TypeOf(m)), protojson.Format(m))
return nil
}

//go:generate go run genhandler.go Item

//go:generate go run genhandler.go Query
//go:generate go run genhandler.go QueryResponse
//go:generate go run genhandler.go QueryError
//go:generate go run genhandler.go CancelQuery
//go:generate go run genhandler.go UndoQuery
Expand Down
4 changes: 1 addition & 3 deletions encoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ var metadata = Metadata{
RecursionBehaviour: &Query_RecursionBehaviour{
LinkDepth: 12,
},
Scope: "testScope",
ItemSubject: "items",
ResponseSubject: "responses",
Scope: "testScope",
},
Timestamp: timestamppb.Now(),
SourceDuration: &durationpb.Duration{
Expand Down
55 changes: 55 additions & 0 deletions handler_queryresponse.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 5 additions & 2 deletions items.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,11 @@ func (qrb *Query_RecursionBehaviour) Copy(dest *Query_RecursionBehaviour) {
dest.FollowOnlyBlastPropagation = qrb.FollowOnlyBlastPropagation
}

// Subject returns the NATS subject used for all traffic relating to this
// query. It is the literal prefix "query." followed by the default formatting
// of the value returned by ParseUuid.
func (q *Query) Subject() string {
	id := q.ParseUuid()
	return fmt.Sprintf("query.%v", id)
}

// Copy copies all information from one Query pointer to another
func (q *Query) Copy(dest *Query) {
dest.Type = q.Type
Expand All @@ -293,8 +298,6 @@ func (q *Query) Copy(dest *Query) {
q.RecursionBehaviour.Copy(dest.RecursionBehaviour)
}
dest.Scope = q.Scope
dest.ItemSubject = q.ItemSubject
dest.ResponseSubject = q.ResponseSubject
dest.IgnoreCache = q.IgnoreCache
dest.UUID = q.UUID

Expand Down
112 changes: 27 additions & 85 deletions progress.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (rs *ResponseSender) Start(ctx context.Context, ec EncodedConnection, respo
rs.connection.Publish(
ctx,
rs.ResponseSubject,
&resp,
&QueryResponse{ResponseType: &QueryResponse_Response{Response: &resp}},
)
}

Expand Down Expand Up @@ -104,7 +104,7 @@ func (rs *ResponseSender) Start(ctx context.Context, ec EncodedConnection, respo
ec.Publish(
ctx,
rs.ResponseSubject,
r,
&QueryResponse{ResponseType: &QueryResponse_Response{Response: r}},
)
}
return
Expand All @@ -118,7 +118,7 @@ func (rs *ResponseSender) Start(ctx context.Context, ec EncodedConnection, respo
ec.Publish(
ctx,
rs.ResponseSubject,
r,
&QueryResponse{ResponseType: &QueryResponse_Response{Response: r}},
)
}
}
Expand Down Expand Up @@ -222,7 +222,7 @@ func (re *Responder) LastStateTime() time.Time {
return re.lastStateTime
}

// QueryProgress represents the status of a request
// QueryProgress represents the status of a query
type QueryProgress struct {
// How long to wait after `MarkStarted()` has been called to get at least
// one responder, if there are no responders in this time, the request will
Expand Down Expand Up @@ -251,10 +251,7 @@ type QueryProgress struct {
cancelled bool
subMutex sync.Mutex

// NATS subscriptions
itemSub *nats.Subscription
responseSub *nats.Subscription
errorSub *nats.Subscription
querySub *nats.Subscription

// Counters for how many things we have sent over the channels. This is
// required to make sure that we aren't closing channels that have pending
Expand Down Expand Up @@ -309,19 +306,6 @@ func (qp *QueryProgress) Start(ctx context.Context, ec EncodedConnection, itemCh

qp.requestCtx = ctx

// Populate inboxes if they aren't already
if qp.Query.ItemSubject == "" {
qp.Query.ItemSubject = fmt.Sprintf("return.item.%v", nats.NewInbox())
}

if qp.Query.ResponseSubject == "" {
qp.Query.ResponseSubject = fmt.Sprintf("return.response.%v", nats.NewInbox())
}

if qp.Query.ErrorSubject == "" {
qp.Query.ErrorSubject = fmt.Sprintf("return.error.%v", nats.NewInbox())
}

if len(qp.Query.UUID) == 0 {
u := uuid.New()
qp.Query.UUID = u[:]
Expand Down Expand Up @@ -350,7 +334,7 @@ func (qp *QueryProgress) Start(ctx context.Context, ec EncodedConnection, itemCh

var err error

qp.itemSub, err = ec.Subscribe(qp.Query.ItemSubject, NewItemHandler("Request.ItemSubject", func(ctx context.Context, item *Item) {
itemHandler := func(ctx context.Context, item *Item) {
defer atomic.AddInt64(qp.itemsProcessed, 1)

span := trace.SpanFromContext(ctx)
Expand Down Expand Up @@ -390,13 +374,9 @@ func (qp *QueryProgress) Start(ctx context.Context, ec EncodedConnection, itemCh

qp.itemChan <- item
}
}))

if err != nil {
return err
}

qp.errorSub, err = ec.Subscribe(qp.Query.ErrorSubject, NewQueryErrorHandler("Request.ErrorSubject", func(ctx context.Context, err *QueryError) {
errorHandler := func(ctx context.Context, err *QueryError) {
defer atomic.AddInt64(qp.errorsProcessed, 1)

if err != nil {
Expand Down Expand Up @@ -432,16 +412,24 @@ func (qp *QueryProgress) Start(ctx context.Context, ec EncodedConnection, itemCh

qp.errorChan <- err
}
}))

if err != nil {
return err
}

qp.responseSub, err = ec.Subscribe(qp.Query.ResponseSubject, NewResponseHandler("ProcessResponse", qp.ProcessResponse))

qp.querySub, err = ec.Subscribe(qp.Query.Subject(), NewQueryResponseHandler("QueryProgress", func(ctx context.Context, qr *QueryResponse) {
log.WithContext(ctx).WithFields(log.Fields{
"response": qr,
}).Info("Received response")
switch qr.ResponseType.(type) {
case *QueryResponse_NewItem:
itemHandler(ctx, qr.GetNewItem())
case *QueryResponse_Error:
errorHandler(ctx, qr.GetError())
case *QueryResponse_Response:
qp.ProcessResponse(ctx, qr.GetResponse())
default:
panic(fmt.Sprintf("Received unexpected QueryResponse: %v", qr))
}
}))
if err != nil {
qp.itemSub.Unsubscribe()
return err
}

Expand Down Expand Up @@ -507,44 +495,7 @@ func (qp *QueryProgress) Drain() {
}

// Close the query subscription
unsubscribeGracefully(qp.itemSub)
unsubscribeGracefully(qp.errorSub)

if qp.responseSub != nil {
// Drain the response connection to, but don't wait for callbacks to finish.
// this is because this code here is likely called as part of a callback and
// therefore would cause deadlock as it essentially waits for itself to
// finish
qp.responseSub.Unsubscribe()
}

// This double-checks that all callbacks are *definitely* complete to avoid
// a situation where we close the channel with a goroutine still pending a
// send. This is rare due to the use of RWMutex on the channel, but still
// possible
var itemsDelivered int64
var errorsDelivered int64
var err error

for {
itemsDelivered, err = qp.itemSub.Delivered()

if err != nil {
break
}

errorsDelivered, err = qp.errorSub.Delivered()

if err != nil {
break
}

if (itemsDelivered == *qp.itemsProcessed) && (errorsDelivered == *qp.errorsProcessed) {
break
}

time.Sleep(50 * time.Millisecond)
}
unsubscribeGracefully(qp.querySub)

qp.chanMutex.Lock()
defer qp.chanMutex.Unlock()
Expand Down Expand Up @@ -923,30 +874,21 @@ func stallMonitor(context context.Context, timeout time.Duration, responder *Res

// unsubscribeGracefully Closes a NATS subscription gracefully, this includes
// draining, unsubscribing and ensuring that all callbacks are complete
func unsubscribeGracefully(c *nats.Subscription) error {
if c != nil {
func unsubscribeGracefully(s *nats.Subscription) error {
if s != nil {
// Drain NATS connections
err := c.Drain()
err := s.Drain()

if err != nil {
// If that fails, fall back to an unsubscribe
err = c.Unsubscribe()
err = s.Unsubscribe()

if err != nil {
return err
}
}

// Wait for all items to finish processing, including all callbacks
for {
messages, _, _ := c.Pending()

if messages > 0 {
time.Sleep(50 * time.Millisecond)
} else {
break
}
}
}

return nil
Expand Down
Loading

0 comments on commit 44c5c6a

Please sign in to comment.