From ffe557cd14547ced21d60746d75c929f54918ba9 Mon Sep 17 00:00:00 2001 From: David Schmitt Date: Mon, 30 Oct 2023 12:35:18 +0100 Subject: [PATCH] Example use of sdpws client --- cmd/request.go | 416 +++++++++++++++++++++++++------------------------ 1 file changed, 211 insertions(+), 205 deletions(-) diff --git a/cmd/request.go b/cmd/request.go index 9f6dcf6e..53cfdcea 100644 --- a/cmd/request.go +++ b/cmd/request.go @@ -3,7 +3,6 @@ package cmd import ( "context" "encoding/json" - "errors" "fmt" "os" "os/signal" @@ -14,16 +13,14 @@ import ( "github.com/google/uuid" "github.com/overmindtech/ovm-cli/tracing" "github.com/overmindtech/sdp-go" + "github.com/overmindtech/sdp-go/sdpws" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "github.com/spf13/viper" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" - "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/types/known/timestamppb" - "nhooyr.io/websocket" - "nhooyr.io/websocket/wspb" ) // requestCmd represents the start command @@ -108,21 +105,16 @@ func Request(ctx context.Context, ready chan bool) int { return 1 } - options := &websocket.DialOptions{ - HTTPClient: NewAuthenticatedClient(ctx, otelhttp.DefaultClient), - } - - // nolint: bodyclose // nhooyr.io/websocket reads the body internally - c, _, err := websocket.Dial(ctx, gatewayUrl, options) + c, err := sdpws.Dial(ctx, gatewayUrl, + NewAuthenticatedClient(ctx, otelhttp.DefaultClient), + &sdpws.LoggingGatewayMessageHandler{Level: log.InfoLevel}, + ) if err != nil { lf["gateway-url"] = gatewayUrl log.WithContext(ctx).WithFields(lf).WithError(err).Error("Failed to connect to overmind API") return 1 } - defer c.Close(websocket.StatusGoingAway, "") - - // the default, 32kB is too small for cert bundles and rds-db-cluster-parameter-groups - c.SetReadLimit(2 * 1024 * 1024) + defer c.Close(ctx) // Log the request in JSON b, err := json.MarshalIndent(req, "", " ") @@ -132,186 +124,191 @@ func Request(ctx context.Context, ready chan bool) int { } log.WithContext(ctx).WithFields(lf).Infof("Request:\n%v", string(b)) - - err = wspb.Write(ctx, c, req) + q, err := createQuery() if err != nil { - log.WithContext(ctx).WithFields(lf).WithError(err).Error("Failed to send request") + log.WithContext(ctx).WithFields(lf).WithError(err).Error("Failed to create query") return 1 } - - queriesSent := true - - responses := make(chan *sdp.GatewayResponse) - - // Start a goroutine that reads responses - go func() { - for { - res := new(sdp.GatewayResponse) - - err = wspb.Read(ctx, c, res) - - if err != nil { - var e websocket.CloseError - if errors.As(err, &e) { - log.WithContext(ctx).WithFields(log.Fields{ - "code": e.Code.String(), - "reason": e.Reason, - }).Info("Websocket closing") - return - } - log.WithContext(ctx).WithFields(log.Fields{ - "error": err, - }).Error("Failed to read response") - return - } - - responses <- res - } - }() - - activeQueries := make(map[uuid.UUID]bool) - - var numItems, numEdges int - - // Read the responses -responses: - for { - select { - case <-ctx.Done(): - log.WithContext(ctx).WithFields(lf).Info("Context cancelled, exiting") - return 1 - - case resp := <-responses: - switch resp.ResponseType.(type) { - case *sdp.GatewayResponse_Status: - status := resp.GetStatus() - statusFields := log.Fields{ - "summary": status.Summary, - "responders": status.Summary.Responders, - "queriesSent": queriesSent, - "postProcessingComplete": status.PostProcessingComplete, - "itemsReceived": numItems, - "edgesReceived": numEdges, - } - - if status.Done() { - // fall through from all "final" query states, check if there's still queries in progress; - // only break from the loop if all queries have already been sent - // TODO: see above, still needs DefaultStartTimeout implemented to account for slow sources - allDone := allDone(ctx, activeQueries, lf) - statusFields["allDone"] = allDone - if allDone && queriesSent { - log.WithContext(ctx).WithFields(lf).WithFields(statusFields).Info("all responders and queries done") - break responses - } else { - log.WithContext(ctx).WithFields(lf).WithFields(statusFields).Info("all responders done, with unfinished queries") - } - } else { - log.WithContext(ctx).WithFields(lf).WithFields(statusFields).Info("still waiting for responders") - } - - case *sdp.GatewayResponse_QueryStatus: - status := resp.GetQueryStatus() - statusFields := log.Fields{ - "status": status.Status.String(), - } - queryUuid := status.GetUUIDParsed() - if queryUuid == nil { - log.WithContext(ctx).WithFields(lf).WithFields(statusFields).Debugf("Received QueryStatus with nil UUID") - continue responses - } - statusFields["query"] = queryUuid - - switch status.Status { - case sdp.QueryStatus_UNSPECIFIED: - statusFields["unexpected_status"] = true - case sdp.QueryStatus_STARTED: - activeQueries[*queryUuid] = true - case sdp.QueryStatus_FINISHED: - activeQueries[*queryUuid] = false - case sdp.QueryStatus_ERRORED: - activeQueries[*queryUuid] = false - case sdp.QueryStatus_CANCELLED: - activeQueries[*queryUuid] = false - default: - statusFields["unexpected_status"] = true - } - - log.WithContext(ctx).WithFields(lf).WithFields(statusFields).Debugf("query status update") - - case *sdp.GatewayResponse_NewItem: - item := resp.GetNewItem() - numItems += 1 - log.WithContext(ctx).WithFields(lf).WithField("item", item.GloballyUniqueName()).Infof("new item") - - case *sdp.GatewayResponse_NewEdge: - edge := resp.GetNewEdge() - numEdges += 1 - log.WithContext(ctx).WithFields(lf).WithFields(log.Fields{ - "from": edge.From.GloballyUniqueName(), - "to": edge.To.GloballyUniqueName(), - }).Info("new edge") - - case *sdp.GatewayResponse_QueryError: - err := resp.GetQueryError() - log.WithContext(ctx).WithFields(lf).Errorf("Error from %v(%v): %v", err.ResponderName, err.SourceName, err) - - case *sdp.GatewayResponse_Error: - err := resp.GetError() - log.WithContext(ctx).WithFields(lf).Errorf("generic error: %v", err) - - default: - j := protojson.Format(resp) - log.WithContext(ctx).WithFields(lf).Infof("Unknown %T Response:\n%v", resp.ResponseType, j) - } - } - } - - if viper.GetBool("snapshot-after") { - log.WithContext(ctx).Info("Starting snapshot") - msgId := uuid.New() - snapReq := &sdp.GatewayRequest{ - MinStatusInterval: minStatusInterval, - RequestType: &sdp.GatewayRequest_StoreSnapshot{ - StoreSnapshot: &sdp.StoreSnapshot{ - Name: viper.GetString("snapshot-name"), - Description: viper.GetString("snapshot-description"), - MsgID: msgId[:], - }, - }, - } - err = wspb.Write(ctx, c, snapReq) - if err != nil { - log.WithContext(ctx).WithFields(log.Fields{ - "error": err, - }).Error("Failed to send snapshot request") - return 1 - } - - for { - select { - case <-ctx.Done(): - log.WithContext(ctx).Info("Context cancelled, exiting") - return 1 - case resp := <-responses: - switch resp.ResponseType.(type) { - case *sdp.GatewayResponse_SnapshotStoreResult: - result := resp.GetSnapshotStoreResult() - if result.Success { - log.WithContext(ctx).Infof("Snapshot stored successfully: %v", uuid.UUID(result.SnapshotID)) - return 0 - } - - log.WithContext(ctx).Errorf("Snapshot store failed: %v", result.ErrorMessage) - return 1 - default: - j := protojson.Format(resp) - - log.WithContext(ctx).Infof("Unknown %T Response:\n%v", resp.ResponseType, j) - } - } - } + err = c.SendQuery(ctx, q) + if err != nil { + log.WithContext(ctx).WithFields(lf).WithError(err).Error("Failed to execute query") + return 1 } + log.WithContext(ctx).WithFields(lf).WithError(err).Info("received items") + + // queriesSent := true + + // responses := make(chan *sdp.GatewayResponse) + + // // Start a goroutine that reads responses + // go func() { + // for { + // res := new(sdp.GatewayResponse) + + // err = wspb.Read(ctx, c, res) + + // if err != nil { + // var e websocket.CloseError + // if errors.As(err, &e) { + // log.WithContext(ctx).WithFields(log.Fields{ + // "code": e.Code.String(), + // "reason": e.Reason, + // }).Info("Websocket closing") + // return + // } + // log.WithContext(ctx).WithFields(log.Fields{ + // "error": err, + // }).Error("Failed to read response") + // return + // } + + // responses <- res + // } + // }() + + // activeQueries := make(map[uuid.UUID]bool) + + // var numItems, numEdges int + + // // Read the responses + // responses: + // for { + // select { + // case <-ctx.Done(): + // log.WithContext(ctx).WithFields(lf).Info("Context cancelled, exiting") + // return 1 + + // case resp := <-responses: + // switch resp.ResponseType.(type) { + // case *sdp.GatewayResponse_Status: + // status := resp.GetStatus() + // statusFields := log.Fields{ + // "summary": status.Summary, + // "responders": status.Summary.Responders, + // "queriesSent": queriesSent, + // "postProcessingComplete": status.PostProcessingComplete, + // "itemsReceived": numItems, + // "edgesReceived": numEdges, + // } + + // if status.Done() { + // // fall through from all "final" query states, check if there's still queries in progress; + // // only break from the loop if all queries have already been sent + // // TODO: see above, still needs DefaultStartTimeout implemented to account for slow sources + // allDone := allDone(ctx, activeQueries, lf) + // statusFields["allDone"] = allDone + // if allDone && queriesSent { + // log.WithContext(ctx).WithFields(lf).WithFields(statusFields).Info("all responders and queries done") + // break responses + // } else { + // log.WithContext(ctx).WithFields(lf).WithFields(statusFields).Info("all responders done, with unfinished queries") + // } + // } else { + // log.WithContext(ctx).WithFields(lf).WithFields(statusFields).Info("still waiting for responders") + // } + + // case *sdp.GatewayResponse_QueryStatus: + // status := resp.GetQueryStatus() + // statusFields := log.Fields{ + // "status": status.Status.String(), + // } + // queryUuid := status.GetUUIDParsed() + // if queryUuid == nil { + // log.WithContext(ctx).WithFields(lf).WithFields(statusFields).Debugf("Received QueryStatus with nil UUID") + // continue responses + // } + // statusFields["query"] = queryUuid + + // switch status.Status { + // case sdp.QueryStatus_UNSPECIFIED: + // statusFields["unexpected_status"] = true + // case sdp.QueryStatus_STARTED: + // activeQueries[*queryUuid] = true + // case sdp.QueryStatus_FINISHED: + // activeQueries[*queryUuid] = false + // case sdp.QueryStatus_ERRORED: + // activeQueries[*queryUuid] = false + // case sdp.QueryStatus_CANCELLED: + // activeQueries[*queryUuid] = false + // default: + // statusFields["unexpected_status"] = true + // } + + // log.WithContext(ctx).WithFields(lf).WithFields(statusFields).Debugf("query status update") + + // case *sdp.GatewayResponse_NewItem: + // item := resp.GetNewItem() + // numItems += 1 + // log.WithContext(ctx).WithFields(lf).WithField("item", item.GloballyUniqueName()).Infof("new item") + + // case *sdp.GatewayResponse_NewEdge: + // edge := resp.GetNewEdge() + // numEdges += 1 + // log.WithContext(ctx).WithFields(lf).WithFields(log.Fields{ + // "from": edge.From.GloballyUniqueName(), + // "to": edge.To.GloballyUniqueName(), + // }).Info("new edge") + + // case *sdp.GatewayResponse_QueryError: + // err := resp.GetQueryError() + // log.WithContext(ctx).WithFields(lf).Errorf("Error from %v(%v): %v", err.ResponderName, err.SourceName, err) + + // case *sdp.GatewayResponse_Error: + // err := resp.GetError() + // log.WithContext(ctx).WithFields(lf).Errorf("generic error: %v", err) + + // default: + // j := protojson.Format(resp) + // log.WithContext(ctx).WithFields(lf).Infof("Unknown %T Response:\n%v", resp.ResponseType, j) + // } + // } + // } + + // if viper.GetBool("snapshot-after") { + // log.WithContext(ctx).Info("Starting snapshot") + // msgId := uuid.New() + // snapReq := &sdp.GatewayRequest{ + // MinStatusInterval: minStatusInterval, + // RequestType: &sdp.GatewayRequest_StoreSnapshot{ + // StoreSnapshot: &sdp.StoreSnapshot{ + // Name: viper.GetString("snapshot-name"), + // Description: viper.GetString("snapshot-description"), + // MsgID: msgId[:], + // }, + // }, + // } + // err = wspb.Write(ctx, c, snapReq) + // if err != nil { + // log.WithContext(ctx).WithFields(log.Fields{ + // "error": err, + // }).Error("Failed to send snapshot request") + // return 1 + // } + + // for { + // select { + // case <-ctx.Done(): + // log.WithContext(ctx).Info("Context cancelled, exiting") + // return 1 + // case resp := <-responses: + // switch resp.ResponseType.(type) { + // case *sdp.GatewayResponse_SnapshotStoreResult: + // result := resp.GetSnapshotStoreResult() + // if result.Success { + // log.WithContext(ctx).Infof("Snapshot stored successfully: %v", uuid.UUID(result.SnapshotID)) + // return 0 + // } + + // log.WithContext(ctx).Errorf("Snapshot store failed: %v", result.ErrorMessage) + // return 1 + // default: + // j := protojson.Format(resp) + + // log.WithContext(ctx).Infof("Unknown %T Response:\n%v", resp.ResponseType, j) + // } + // } + // } + // } return 0 } @@ -331,33 +328,42 @@ func methodFromString(method string) (sdp.QueryMethod, error) { return result, nil } +func createQuery() (*sdp.Query, error) { + u := uuid.New() + method, err := methodFromString(viper.GetString("query-method")) + if err != nil { + return nil, err + } + + return &sdp.Query{ + Method: method, + Type: viper.GetString("query-type"), + Query: viper.GetString("query"), + Scope: viper.GetString("query-scope"), + Deadline: timestamppb.New(time.Now().Add(10 * time.Hour)), + UUID: u[:], + RecursionBehaviour: &sdp.Query_RecursionBehaviour{ + LinkDepth: viper.GetUint32("link-depth"), + FollowOnlyBlastPropagation: viper.GetBool("blast-radius"), + }, + IgnoreCache: viper.GetBool("ignore-cache"), + }, nil +} + func createInitialRequest() (*sdp.GatewayRequest, error) { req := &sdp.GatewayRequest{ MinStatusInterval: minStatusInterval, } - u := uuid.New() switch viper.GetString("request-type") { case "query": - method, err := methodFromString(viper.GetString("query-method")) + q, err := createQuery() if err != nil { return nil, err } req.RequestType = &sdp.GatewayRequest_Query{ - Query: &sdp.Query{ - Method: method, - Type: viper.GetString("query-type"), - Query: viper.GetString("query"), - Scope: viper.GetString("query-scope"), - Deadline: timestamppb.New(time.Now().Add(10 * time.Hour)), - UUID: u[:], - RecursionBehaviour: &sdp.Query_RecursionBehaviour{ - LinkDepth: viper.GetUint32("link-depth"), - FollowOnlyBlastPropagation: viper.GetBool("blast-radius"), - }, - IgnoreCache: viper.GetBool("ignore-cache"), - }, + Query: q, } case "load-bookmark": bookmarkUUID, err := uuid.Parse(viper.GetString("bookmark-uuid"))