Skip to content

Commit

Permalink
Add stream interceptors (#190)
Browse files Browse the repository at this point in the history
* Add stream interceptors

* Add UTs, fix subject field handling

* Fix tests

* Improve subject value handling

* Remove redundant ctx returns, fix populating fields into cxt logger

* Add notes to generate mocks

* Add blank line
  • Loading branch information
addudko authored May 19, 2020
1 parent d048f15 commit c488b8e
Show file tree
Hide file tree
Showing 7 changed files with 1,017 additions and 30 deletions.
4 changes: 3 additions & 1 deletion Gopkg.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,8 @@ check-fmt:

.PHONY: gen
gen: .gen-query .gen-errdetails .gen-errfields

.PHONY: mocks
mocks:
GO111MODULE=off go get -u github.com/maxbrunsfeld/counterfeiter
counterfeiter --fake-name ServerStreamMock -o ./logging/mocks/server_stream.go $(GOPATH)/src/github.com/infobloxopen/atlas-app-toolkit/vendor/google.golang.org/grpc/stream.go ServerStream
7 changes: 7 additions & 0 deletions logging/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,10 @@ For example:
## Other functions

The helper function `CopyLoggerWithLevel` can be used to make a deep copy of a logger at a new level, or using `CopyLoggerWithLevel(entry.Logger, level).WithFields(entry.Data)` can copy a logrus.Entry.

## Generate mocks

Mocks generated with this [tool](https://github.com/maxbrunsfeld/counterfeiter). Generate mocks for logging tests via:
```makefile
make mocks
```
140 changes: 111 additions & 29 deletions logging/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ import (
"strings"
"time"

grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_logrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus"
"github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus"
"github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"

"github.com/infobloxopen/atlas-app-toolkit/auth"
Expand Down Expand Up @@ -62,14 +62,14 @@ func LogLevelInterceptor(defaultLevel logrus.Level) grpc.UnaryServerInterceptor
}
}

func UnaryClientInterceptor(logger *logrus.Logger, opts ...Option) grpc.UnaryClientInterceptor {
func UnaryClientInterceptor(entry *logrus.Entry, opts ...Option) grpc.UnaryClientInterceptor {
options := initOptions(opts)

return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
startTime := time.Now()
fields := newLoggerFields(method, startTime, DefaultClientKindValue)

ctx = setInterceptorFields(ctx, fields, logger, options, startTime)
setInterceptorFields(ctx, fields, entry.Logger, options, startTime)

err := invoker(ctx, method, req, reply, cc, opts...)
if err != nil {
Expand All @@ -80,23 +80,50 @@ func UnaryClientInterceptor(logger *logrus.Logger, opts ...Option) grpc.UnaryCli
fields[DefaultGRPCCodeKey] = code.String()

levelLogf(
logrus.NewEntry(logger).WithFields(fields),
entry.WithFields(fields),
options.codeToLevel(code),
"finished unary call with code "+code.String())
"finished unary call with code %s", code.String())

return err
}
}

func UnaryServerInterceptor(logger *logrus.Logger, opts ...Option) grpc.UnaryServerInterceptor {
func StreamClientInterceptor(entry *logrus.Entry, opts ...Option) grpc.StreamClientInterceptor {
options := initOptions(opts)

return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, option ...grpc.CallOption) (grpc.ClientStream, error) {
startTime := time.Now()
fields := newLoggerFields(method, startTime, DefaultClientKindValue)

setInterceptorFields(ctx, fields, entry.Logger, options, startTime)

clientStream, err := streamer(ctx, desc, cc, method, option...)
if err != nil {
fields[logrus.ErrorKey] = err
}

code := status.Code(err)
fields[DefaultGRPCCodeKey] = code.String()

levelLogf(
entry.WithFields(fields),
options.codeToLevel(code),
"finished client streaming call with code %s", code.String())

return clientStream, err
}
}

func UnaryServerInterceptor(entry *logrus.Entry, opts ...Option) grpc.UnaryServerInterceptor {
options := initOptions(opts)

return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
startTime := time.Now()
fields := newLoggerFields(info.FullMethod, startTime, DefaultServerKindValue)
newCtx := newLoggerForCall(ctx, logrus.NewEntry(logger), fields)

newCtx = setInterceptorFields(newCtx, fields, logger, options, startTime)
setInterceptorFields(ctx, fields, entry.Logger, options, startTime)

newCtx := newLoggerForCall(ctx, entry, fields)

resp, err := handler(newCtx, req)
if err != nil {
Expand All @@ -109,92 +136,147 @@ func UnaryServerInterceptor(logger *logrus.Logger, opts ...Option) grpc.UnarySer
levelLogf(
ctxlogrus.Extract(newCtx).WithFields(fields),
options.codeToLevel(code),
"finished unary call with code "+code.String())
"finished unary call with code %s", code.String())

return resp, err
}
}

func setInterceptorFields(ctx context.Context, fields logrus.Fields, logger *logrus.Logger, options *options, start time.Time) context.Context {
func StreamServerInterceptor(entry *logrus.Entry, opts ...Option) grpc.StreamServerInterceptor {
options := initOptions(opts)

return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
startTime := time.Now()
fields := newLoggerFields(info.FullMethod, startTime, DefaultServerKindValue)

setInterceptorFields(stream.Context(), fields, entry.Logger, options, startTime)

newCtx := newLoggerForCall(stream.Context(), entry, fields)

wrapped := grpc_middleware.WrapServerStream(stream)
wrapped.WrappedContext = newCtx

err := handler(srv, wrapped)
if err != nil {
fields[logrus.ErrorKey] = err
}

code := status.Code(err)
fields[DefaultGRPCCodeKey] = code.String()

levelLogf(
ctxlogrus.Extract(newCtx).WithFields(fields),
options.codeToLevel(code),
"finished server streaming call with code %s", code.String())

return err
}
}

func setInterceptorFields(ctx context.Context, fields logrus.Fields, logger *logrus.Logger, options *options, start time.Time) {
// In latest versions of Go use
// https://golang.org/src/time/time.go?s=25178:25216#L780
duration := int64(time.Since(start) / 1e6)
fields[DefaultDurationKey] = duration

ctx, err := addRequestIDField(ctx, fields)
err := addRequestIDField(ctx, fields)
if err != nil {
logger.Warn(err)
}

ctx, err = addAccountIDField(ctx, fields)
err = addAccountIDField(ctx, fields)
if err != nil {
logger.Warn(err)
}

ctx, err = addCustomField(ctx, fields, DefaultSubjectKey)
err = addCustomField(ctx, fields, DefaultSubjectKey)
if err != nil {
logger.Warn(err)
}

for _, v := range options.fields {
ctx, err = addCustomField(ctx, fields, v)
err = addCustomField(ctx, fields, v)
if err != nil {
logger.Warn(err)
}
}

for _, v := range options.headers {
ctx, err = addHeaderField(ctx, fields, v)
err = addHeaderField(ctx, fields, v)
if err != nil {
logger.Warn(err)
}
}

return ctx
}

func addRequestIDField(ctx context.Context, fields logrus.Fields) (context.Context, error) {
func addRequestIDField(ctx context.Context, fields logrus.Fields) error {
reqID, exists := requestid.FromContext(ctx)
if !exists || reqID == "" {
return ctx, fmt.Errorf("Unable to get %q from context", DefaultRequestIDKey)
return fmt.Errorf("Unable to get %q from context", DefaultRequestIDKey)
}

fields[DefaultRequestIDKey] = reqID

return metadata.AppendToOutgoingContext(ctx, DefaultRequestIDKey, reqID), nil
return nil
}

func addAccountIDField(ctx context.Context, fields logrus.Fields) (context.Context, error) {
func addAccountIDField(ctx context.Context, fields logrus.Fields) error {
accountID, err := auth.GetAccountID(ctx, nil)
if err != nil {
return ctx, fmt.Errorf("Unable to get %q from context", DefaultAccountIDKey)
return fmt.Errorf("Unable to get %q from context", DefaultAccountIDKey)
}

fields[DefaultAccountIDKey] = accountID

return metadata.AppendToOutgoingContext(ctx, DefaultAccountIDKey, accountID), err
return err
}

func addCustomField(ctx context.Context, fields logrus.Fields, customField string) (context.Context, error) {
func addCustomField(ctx context.Context, fields logrus.Fields, customField string) error {
field, err := auth.GetJWTField(ctx, customField, nil)
if err != nil {
return ctx, fmt.Errorf("Unable to get custom %q field from context", customField)
return fmt.Errorf("Unable to get custom %q field from context", customField)
}

// In case of subject field is a map
if customField == DefaultSubjectKey {

replacer := strings.NewReplacer("map[", "", "]", "")
field = replacer.Replace(field)
inner := strings.Split(field, " ")

m := map[string]interface{}{}

for _, v := range inner {
kv := strings.Split(v, ":")

if len(kv) == 1 {
fields[customField] = kv[0]

return err
}

m[kv[0]] = kv[1]
}

fields[customField] = m

return err
}

fields[customField] = field

return metadata.AppendToOutgoingContext(ctx, customField, field), err
return err
}

func addHeaderField(ctx context.Context, fields logrus.Fields, header string) (context.Context, error) {
func addHeaderField(ctx context.Context, fields logrus.Fields, header string) error {
field, ok := gateway.Header(ctx, header)
if !ok {
return ctx, fmt.Errorf("Unable to get custom header %q from context", header)
return fmt.Errorf("Unable to get custom header %q from context", header)
}

fields[strings.ToLower(header)] = field

return metadata.AppendToOutgoingContext(ctx, header, field), nil
return nil
}

func newLoggerFields(fullMethodString string, start time.Time, kind string) logrus.Fields {
Expand Down
Loading

0 comments on commit c488b8e

Please sign in to comment.