diff --git a/sdk/filter/multifilter_test.go b/sdk/filter/multifilter_test.go index a94b430..afdb557 100644 --- a/sdk/filter/multifilter_test.go +++ b/sdk/filter/multifilter_test.go @@ -11,11 +11,7 @@ import ( func TestMultiFilterEmpty(t *testing.T) { f := NewMultiFilter() - res := f.EvaluateURLAndHeaders(nil, "", nil) - assert.False(t, res.Block) - res = f.EvaluateBody(nil, nil, nil) - assert.False(t, res.Block) - res = f.Evaluate(nil) + res := f.Evaluate(nil) assert.False(t, res.Block) } @@ -26,65 +22,21 @@ func TestMultiFilterStopsAfterTrue(t *testing.T) { expectedFilterResult bool multiFilter *MultiFilter }{ - "URL and Headers multi filter": { - expectedURLAndHeadersFilterResult: true, - expectedBodyFilterResult: false, - expectedFilterResult: false, - multiFilter: NewMultiFilter( - mock.Filter{ - URLAndHeadersEvaluator: func(span sdk.Span, url string, headers map[string][]string) result.FilterResult { - return result.FilterResult{} - }, - }, - mock.Filter{ - URLAndHeadersEvaluator: func(span sdk.Span, url string, headers map[string][]string) result.FilterResult { - return result.FilterResult{Block: true, ResponseStatusCode: 403} - }, - }, - mock.Filter{ - URLAndHeadersEvaluator: func(span sdk.Span, url string, headers map[string][]string) result.FilterResult { - assert.Fail(t, "should not be called") - return result.FilterResult{} - }, - }, - ), - }, - "Body multi filter": { - expectedBodyFilterResult: true, - multiFilter: NewMultiFilter( - mock.Filter{ - BodyEvaluator: func(span sdk.Span, body []byte, headers map[string][]string) result.FilterResult { - return result.FilterResult{} - }, - }, - mock.Filter{ - BodyEvaluator: func(span sdk.Span, body []byte, headers map[string][]string) result.FilterResult { - return result.FilterResult{Block: true, ResponseStatusCode: 403} - }, - }, - mock.Filter{ - BodyEvaluator: func(span sdk.Span, body []byte, headers map[string][]string) result.FilterResult { - assert.Fail(t, "should not be called") - return result.FilterResult{} - }, - }, - ), - }, "Evaluate multi filter": { expectedFilterResult: true, multiFilter: NewMultiFilter( mock.Filter{ - Evaluator: func(span sdk.Span, url string, body []byte, headers map[string][]string) result.FilterResult { + Evaluator: func(span sdk.Span) result.FilterResult { return result.FilterResult{} }, }, mock.Filter{ - Evaluator: func(span sdk.Span, url string, body []byte, headers map[string][]string) result.FilterResult { + Evaluator: func(span sdk.Span) result.FilterResult { return result.FilterResult{Block: true, ResponseStatusCode: 403} }, }, mock.Filter{ - Evaluator: func(span sdk.Span, url string, body []byte, headers map[string][]string) result.FilterResult { + Evaluator: func(span sdk.Span) result.FilterResult { assert.Fail(t, "should not be called") return result.FilterResult{} }, @@ -95,11 +47,7 @@ func TestMultiFilterStopsAfterTrue(t *testing.T) { for name, tCase := range tCases { t.Run(name, func(t *testing.T) { - res := tCase.multiFilter.EvaluateURLAndHeaders(nil, "", nil) - assert.Equal(t, tCase.expectedURLAndHeadersFilterResult, res.Block) - res = tCase.multiFilter.EvaluateBody(nil, nil, nil) - assert.Equal(t, tCase.expectedBodyFilterResult, res.Block) - res = tCase.multiFilter.Evaluate(nil) + res := tCase.multiFilter.Evaluate(nil) assert.Equal(t, tCase.expectedFilterResult, res.Block) }) } diff --git a/sdk/filter/noop_test.go b/sdk/filter/noop_test.go index 43264b4..c2373fa 100644 --- a/sdk/filter/noop_test.go +++ b/sdk/filter/noop_test.go @@ -8,10 +8,6 @@ import ( func TestNoopFilter(t *testing.T) { f := NoopFilter{} - res := f.EvaluateURLAndHeaders(nil, "", nil) - assert.False(t, res.Block) - res = f.EvaluateBody(nil, nil, nil) - assert.False(t, res.Block) - res = f.Evaluate(nil) + res := f.Evaluate(nil) assert.False(t, res.Block) } diff --git a/sdk/instrumentation/google.golang.org/grpc/server.go b/sdk/instrumentation/google.golang.org/grpc/server.go index ac012c5..a3eeaed 100644 --- a/sdk/instrumentation/google.golang.org/grpc/server.go +++ b/sdk/instrumentation/google.golang.org/grpc/server.go @@ -95,14 +95,6 @@ func wrapHandler( if dataCaptureConfig.RpcMetadata.Request.Value { setAttributesFromRequestIncomingMetadata(ctx, span) - - // TODO: decide what should be passed as URL in GRPC - if !dataCaptureConfig.RpcBody.Request.Value { - filterResult := filter.Evaluate(span) - if filterResult.Block { - return nil, status.Error(StatusCode(int(filterResult.ResponseStatusCode)), StatusText(int(filterResult.ResponseStatusCode))) - } - } } reqBody, err := marshalMessageableJSON(req) @@ -110,10 +102,13 @@ func wrapHandler( len(reqBody) > 0 && err == nil { setTruncatedBodyAttribute("request", reqBody, int(dataCaptureConfig.BodyMaxSizeBytes.Value), span) - filterResult := filter.Evaluate(span) - if filterResult.Block { - return nil, status.Error(StatusCode(int(filterResult.ResponseStatusCode)), StatusText(int(filterResult.ResponseStatusCode))) - } + } + + // TODO: decide what should be passed as URL in GRPC + // single evaluation call to filter after capturing the configured parameters + filterResult := filter.Evaluate(span) + if filterResult.Block { + return nil, status.Error(StatusCode(int(filterResult.ResponseStatusCode)), StatusText(int(filterResult.ResponseStatusCode))) } res, err := delegateHandler(ctx, req) diff --git a/sdk/instrumentation/net/http/handler.go b/sdk/instrumentation/net/http/handler.go index d34dae7..a454d7a 100644 --- a/sdk/instrumentation/net/http/handler.go +++ b/sdk/instrumentation/net/http/handler.go @@ -3,7 +3,6 @@ package http // import "github.com/hypertrace/goagent/sdk/instrumentation/net/ht import ( "bytes" "io" - "io/ioutil" "net/http" config "github.com/hypertrace/agent-config/gen/go/v1" @@ -76,17 +75,10 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { SetAttributesFromHeaders("request", NewHeaderMapAccessor(r.Header), span) } - // run filters on headers - filterResult := h.filter.Evaluate(span) - if filterResult.Block { - w.WriteHeader(int(filterResult.ResponseStatusCode)) - return - } - // nil check for body is important as this block turns the body into another // object that isn't nil and that will leverage the "Observer effect". if r.Body != nil && h.dataCaptureConfig.HttpBody.Request.Value && ShouldRecordBodyOfContentType(headerMapAccessor{r.Header}) { - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { return } @@ -101,14 +93,14 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { isMultipartFormDataBody) } - // run body filters - filterResult := h.filter.Evaluate(span) - if filterResult.Block { - w.WriteHeader(int(filterResult.ResponseStatusCode)) - return - } + r.Body = io.NopCloser(bytes.NewBuffer(body)) + } - r.Body = ioutil.NopCloser(bytes.NewBuffer(body)) + // single evaluation call to filter after capturing the configured parameters + filterResult := h.filter.Evaluate(span) + if filterResult.Block { + w.WriteHeader(int(filterResult.ResponseStatusCode)) + return } // create http.ResponseWriter interceptor for tracking status code diff --git a/sdk/instrumentation/net/http/handler_test.go b/sdk/instrumentation/net/http/handler_test.go index 32b059c..36147eb 100644 --- a/sdk/instrumentation/net/http/handler_test.go +++ b/sdk/instrumentation/net/http/handler_test.go @@ -370,14 +370,10 @@ func TestServerRequestFilter(t *testing.T) { body: "haha", options: &Options{ Filter: mock.Filter{ - URLAndHeadersEvaluator: func(span sdk.Span, url string, headers map[string][]string) result.FilterResult { - assert.Equal(t, "http://localhost/foo", url) - assert.Equal(t, 1, len(headers)) - assert.Equal(t, []string{"application/json"}, headers["Content-Type"]) - return result.FilterResult{} - }, - BodyEvaluator: func(span sdk.Span, body []byte, headers map[string][]string) result.FilterResult { - assert.Equal(t, []byte("haha"), body) + Evaluator: func(span sdk.Span) result.FilterResult { + assert.Equal(t, "http://localhost/foo", span.GetAttributes().GetValue("http.url")) + assert.Equal(t, "application/json", span.GetAttributes().GetValue("http.request.header.content-type")) + assert.Equal(t, "haha", span.GetAttributes().GetValue("http.request.body")) return result.FilterResult{} }, }, @@ -390,14 +386,12 @@ func TestServerRequestFilter(t *testing.T) { body: "haha", options: &Options{ Filter: mock.Filter{ - URLAndHeadersEvaluator: func(span sdk.Span, url string, headers map[string][]string) result.FilterResult { - assert.Equal(t, "http://localhost/foo", url) - assert.Equal(t, 1, len(headers)) - assert.Equal(t, []string{"multipart/form-data"}, headers["Content-Type"]) - return result.FilterResult{} - }, - BodyEvaluator: func(span sdk.Span, body []byte, headers map[string][]string) result.FilterResult { - assert.Equal(t, []byte(base64.RawStdEncoding.EncodeToString([]byte("haha"))), body) + Evaluator: func(span sdk.Span) result.FilterResult { + assert.Equal(t, "http://localhost/foo", span.GetAttributes().GetValue("http.url")) + assert.Equal(t, "multipart/form-data", span.GetAttributes().GetValue("http.request.header.content-type")) + assert.Equal(t, base64.RawStdEncoding.EncodeToString([]byte("haha")), + span.GetAttributes().GetValue("http.request.body.base64")) + return result.FilterResult{} }, }, @@ -411,10 +405,11 @@ func TestServerRequestFilter(t *testing.T) { body: "haha", options: &Options{ Filter: mock.Filter{ - URLAndHeadersEvaluator: func(span sdk.Span, url string, headers map[string][]string) result.FilterResult { - assert.Equal(t, "http://localhost/foo", url) - assert.Equal(t, 1, len(headers)) - assert.Equal(t, []string{"multipart/form-data"}, headers["Content-Type"]) + Evaluator: func(span sdk.Span) result.FilterResult { + assert.Equal(t, "http://localhost/foo", span.GetAttributes().GetValue("http.url")) + assert.Equal(t, "multipart/form-data", span.GetAttributes().GetValue("http.request.header.content-type")) + assert.Nil(t, span.GetAttributes().GetValue("http.request.body")) + assert.Nil(t, span.GetAttributes().GetValue("http.request.body.base64")) return result.FilterResult{} }, }, @@ -425,7 +420,7 @@ func TestServerRequestFilter(t *testing.T) { url: "http://localhost/foo", options: &Options{ Filter: mock.Filter{ - URLAndHeadersEvaluator: func(span sdk.Span, url string, headers map[string][]string) result.FilterResult { + Evaluator: func(span sdk.Span) result.FilterResult { return result.FilterResult{Block: true, ResponseStatusCode: 403} }, }, @@ -436,7 +431,7 @@ func TestServerRequestFilter(t *testing.T) { url: "http://localhost/foo", options: &Options{ Filter: mock.Filter{ - URLAndHeadersEvaluator: func(span sdk.Span, url string, headers map[string][]string) result.FilterResult { + Evaluator: func(span sdk.Span) result.FilterResult { return result.FilterResult{Block: true, ResponseStatusCode: 403} }, }, @@ -450,7 +445,7 @@ func TestServerRequestFilter(t *testing.T) { body: "haha", options: &Options{ Filter: mock.Filter{ - BodyEvaluator: func(span sdk.Span, body []byte, headers map[string][]string) result.FilterResult { + Evaluator: func(span sdk.Span) result.FilterResult { return result.FilterResult{Block: true, ResponseStatusCode: 403} }, }, @@ -501,8 +496,8 @@ func TestProcessingBodyIsTrimmed(t *testing.T) { wh, _ := WrapHandler(h, mock.SpanFromContext, &Options{ Filter: mock.Filter{ - BodyEvaluator: func(span sdk.Span, body []byte, headers map[string][]string) result.FilterResult { - assert.Equal(t, "{", string(body)) // body is truncated + Evaluator: func(span sdk.Span) result.FilterResult { + assert.Equal(t, "{", span.GetAttributes().GetValue("http.request.body")) // body is truncated return result.FilterResult{Block: true, ResponseStatusCode: 403} }, }, @@ -510,6 +505,7 @@ func TestProcessingBodyIsTrimmed(t *testing.T) { wh.dataCaptureConfig = emptyTestConfig wh.dataCaptureConfig.HttpBody.Request = config.Bool(true) wh.dataCaptureConfig.BodyMaxProcessingSizeBytes = config.Int32(int32(bodyMaxProcessingSizeBytes)) + wh.dataCaptureConfig.BodyMaxSizeBytes = config.Int32(int32(bodyMaxProcessingSizeBytes)) ih := &mockHandler{baseHandler: wh}