Skip to content

Commit

Permalink
🧹 small code optimizations (#9)
Browse files Browse the repository at this point in the history
Signed-off-by: Christoph Hartmann <[email protected]>
  • Loading branch information
chris-rock authored Jul 26, 2022
1 parent 9882f59 commit 87378aa
Showing 1 changed file with 35 additions and 38 deletions.
73 changes: 35 additions & 38 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ const (
ContentTypeJson = "application/json"
)

var validContentTypes = []string{
ContentTypeProtobuf,
ContentTypeOctetProtobuf,
ContentTypeGrpcProtobuf,
ContentTypeJson,
var validContentTypes = map[string]struct{}{
ContentTypeProtobuf: {},
ContentTypeOctetProtobuf: {},
ContentTypeGrpcProtobuf: {},
ContentTypeJson: {},
}

// Method represents a RPC method and is used by protoc-gen-rangerrpc
Expand Down Expand Up @@ -57,8 +57,10 @@ func (s *server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
ctx, cancel := context.WithCancel(req.Context())
defer cancel()

contentType := req.Header.Get("Content-Type")

// verify content type
err := verifyContentType(req)
err := verifyContentType(req, contentType)
if err != nil {
httpError(w, req, err)
return
Expand All @@ -72,14 +74,14 @@ func (s *server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// extract the rpc method name and invoke the method
name := strings.TrimPrefix(req.URL.Path, s.prefix)

method := s.service.Methods[name]
if method == nil {
method, ok := s.service.Methods[name]
if !ok {
err := status.Error(codes.NotFound, "method not defined")
httpError(w, req, err)
return
}

rctx, rcancel, body, err := PreProcessRequest(ctx, req)
rctx, rcancel, body, err := preProcessRequest(ctx, req)
if err != nil {
httpError(w, req, err)
return
Expand All @@ -88,16 +90,22 @@ func (s *server) ServeHTTP(w http.ResponseWriter, req *http.Request) {

// invoke method and send the response
resp, err := method(rctx, &body)
s.sendResponse(w, req, resp, err)
if err != nil {
httpError(w, req, err)
return
}
// check if the accept header is set, otherwise use the incoming content type
encodingType := determineResponseType(contentType, req.Header.Get("Accept"))
s.sendResponse(w, req, resp, encodingType)
}

// PreProcessRequest is used to preprocess the incoming request.
// preProcessRequest is used to preprocess the incoming request.
// It returns the context, a cancel function and the body of the request. The cancel function can be used to cancel
// the context. It also adds the http headers to the context.
func PreProcessRequest(ctx context.Context, req *http.Request) (context.Context, context.CancelFunc, []byte, error) {
func preProcessRequest(ctx context.Context, req *http.Request) (context.Context, context.CancelFunc, []byte, error) {
// read body content
body, err := ioutil.ReadAll(req.Body)
req.Body.Close()
defer req.Body.Close()
if err != nil {
return nil, nil, nil, status.Error(codes.DataLoss, "unrecoverable data loss or corruption")
}
Expand All @@ -111,16 +119,8 @@ func PreProcessRequest(ctx context.Context, req *http.Request) (context.Context,
return rctx, rcancel, body, nil
}

func (s *server) sendResponse(w http.ResponseWriter, req *http.Request, resp proto.Message, err error) {
if err != nil {
httpError(w, req, err)
return
}

// check if the accept header is set, otherwise use the incoming content type
accept := determineResponseType(req.Header.Get("Content-Type"), req.Header.Get("Accept"))
payload, contentType, err := convertProtoToPayload(resp, accept)

func (s *server) sendResponse(w http.ResponseWriter, req *http.Request, resp proto.Message, contentType string) {
payload, contentType, err := convertProtoToPayload(resp, contentType)
if err != nil {
httpError(w, req, status.Error(codes.Internal, "error encoding response"))
return
Expand All @@ -134,44 +134,41 @@ func (s *server) sendResponse(w http.ResponseWriter, req *http.Request, resp pro

// convertProtoToPayload converts a proto message to the approaptiate formatted payload.
// Depending on the accept header it will return the payload as marshalled protobuf or json.
func convertProtoToPayload(resp proto.Message, accept string) ([]byte, string, error) {
func convertProtoToPayload(resp proto.Message, contentType string) ([]byte, string, error) {
var err error
var payload []byte
contentType := accept
switch accept {
switch contentType {
case ContentTypeProtobuf, ContentTypeGrpcProtobuf, ContentTypeOctetProtobuf:
contentType = ContentTypeProtobuf
payload, err = proto.Marshal(resp)
// as default, we return json to be compatible with browsers, since they do not
// request as application/json as default
default:
contentType = ContentTypeJson
payload, err = jsonpb.MarshalOptions{UseProtoNames: true}.Marshal(resp)
}

return payload, contentType, err
}

// verifyContentType validates the content type of the request is known.
func verifyContentType(req *http.Request) error {
header := req.Header.Get("Content-Type")

func verifyContentType(req *http.Request, contentType string) error {
// we assume "application/protobuf" if no content-type is set
if len(header) == 0 {
if contentType == "" {
return nil
}

i := strings.Index(header, ";")
i := strings.Index(contentType, ";")
if i == -1 {
i = len(header)
i = len(contentType)
}

ct := strings.TrimSpace(strings.ToLower(header[:i]))
ct := strings.TrimSpace(strings.ToLower(contentType[:i]))

// check that the incoming request has a valid content type
for _, a := range validContentTypes {
if a == ct {
return nil
}
_, ok := validContentTypes[ct]
if ok {
return nil
}

// if we reached here, we have to handle an unexpected incoming type
Expand All @@ -181,7 +178,7 @@ func verifyContentType(req *http.Request) error {
// determineResponseType returns the content type based on the Content-Type and Accept header.
func determineResponseType(contenttype string, accept string) string {
// use provided content type if no accept header was provided
if len(accept) == 0 {
if accept == "" {
accept = contenttype
}

Expand Down

0 comments on commit 87378aa

Please sign in to comment.