diff --git a/client.go b/client.go index c9496c8..d764506 100644 --- a/client.go +++ b/client.go @@ -124,15 +124,11 @@ func OpenTracingStreamClientInterceptor(tracer opentracing.Tracer, optFuncs ...O clientSpan.Finish() return cs, err } - return newOpenTracingClientStream(cs, method, desc, clientSpan, otgrpcOpts), nil + return newOpenTracingClientStream(ctx, cs, method, desc, clientSpan, otgrpcOpts), nil } } -func newOpenTracingClientStream(cs grpc.ClientStream, method string, desc *grpc.StreamDesc, clientSpan opentracing.Span, otgrpcOpts *options) grpc.ClientStream { - // Grab the client stream context because when the finish function or the goroutine below will be - // executed it's not guaranteed cs.Context() will be valid. - csCtx := cs.Context() - +func newOpenTracingClientStream(ctx context.Context, cs grpc.ClientStream, method string, desc *grpc.StreamDesc, clientSpan opentracing.Span, otgrpcOpts *options) grpc.ClientStream { finishChan := make(chan struct{}) isFinished := new(int32) @@ -152,7 +148,7 @@ func newOpenTracingClientStream(cs grpc.ClientStream, method string, desc *grpc. SetSpanTags(clientSpan, err, true) } if otgrpcOpts.decorator != nil { - otgrpcOpts.decorator(csCtx, clientSpan, method, nil, nil, err) + otgrpcOpts.decorator(ctx, clientSpan, method, nil, nil, err) } } go func() { @@ -160,8 +156,15 @@ func newOpenTracingClientStream(cs grpc.ClientStream, method string, desc *grpc. case <-finishChan: // The client span is being finished by another code path; hence, no // action is necessary. - case <-csCtx.Done(): - finishFunc(csCtx.Err()) + case <-ctx.Done(): + // Why use ctx rather than cs.Context()? Two reasons: + // 1. According to its docs, cs.Context() should not be used until after the first Header() or + // RecvMsg() call has returned. + // 2. ClientStream implementations cancel their context as soon as an error is received in Header(), + // RecvMsg(), SendMsg() or CloseSend(). This causes a race between the interceptor logging the + // returned error and this method logging the cancelled context. + // Using ctx avoids both of these issues. + finishFunc(ctx.Err()) } }() otcs := &openTracingClientStream{ diff --git a/test/interceptor_test.go b/test/interceptor_test.go index 4626b68..7d28981 100644 --- a/test/interceptor_test.go +++ b/test/interceptor_test.go @@ -132,9 +132,11 @@ func assertChildParentSpans(t *testing.T, tracer *mocktracer.MockTracer) { if len(spans) != 2 { t.Fatalf("Incorrect span length") } - parent := spans[1] - child := spans[0] - assert.Equal(t, child.ParentID, parent.Context().(mocktracer.MockSpanContext).SpanID) + clientSpan := spans[1] + serverSpan := spans[0] + assert.Equal(t, serverSpan.ParentID, clientSpan.Context().(mocktracer.MockSpanContext).SpanID) + assert.Nil(t, clientSpan.Tag("error")) + assert.Empty(t, clientSpan.Logs()) } func TestUnaryOpenTracing(t *testing.T) {