diff --git a/db/db.go b/db/db.go index bb627dd2b..0a47719d0 100644 --- a/db/db.go +++ b/db/db.go @@ -22,6 +22,7 @@ type Database interface { hardware template workflow + WorkerWorkflow } type hardware interface { @@ -43,20 +44,24 @@ type template interface { type workflow interface { CreateWorkflow(ctx context.Context, wf Workflow, data string, id uuid.UUID) error - InsertIntoWfDataTable(ctx context.Context, req *pb.UpdateWorkflowDataRequest) error - GetfromWfDataTable(ctx context.Context, req *pb.GetWorkflowDataRequest) ([]byte, error) GetWorkflowMetadata(ctx context.Context, req *pb.GetWorkflowDataRequest) ([]byte, error) GetWorkflowDataVersion(ctx context.Context, workflowID string) (int32, error) - GetWorkflowsForWorker(id string) ([]string, error) GetWorkflow(ctx context.Context, id string) (Workflow, error) DeleteWorkflow(ctx context.Context, id string, state int32) error ListWorkflows(fn func(wf Workflow) error) error UpdateWorkflow(ctx context.Context, wf Workflow, state int32) error + InsertIntoWorkflowEventTable(ctx context.Context, wfEvent *pb.WorkflowActionStatus, t time.Time) error + ShowWorkflowEvents(wfID string, fn func(wfs *pb.WorkflowActionStatus) error) error +} + +// WorkerWorkflow is an interface for methods invoked by APIs that the worker calls. +type WorkerWorkflow interface { + InsertIntoWfDataTable(ctx context.Context, req *pb.UpdateWorkflowDataRequest) error + GetfromWfDataTable(ctx context.Context, req *pb.GetWorkflowDataRequest) ([]byte, error) + GetWorkflowsForWorker(ctx context.Context, id string) ([]string, error) UpdateWorkflowState(ctx context.Context, wfContext *pb.WorkflowContext) error GetWorkflowContexts(ctx context.Context, wfID string) (*pb.WorkflowContext, error) GetWorkflowActions(ctx context.Context, wfID string) (*pb.WorkflowActionList, error) - InsertIntoWorkflowEventTable(ctx context.Context, wfEvent *pb.WorkflowActionStatus, t time.Time) error - ShowWorkflowEvents(wfID string, fn func(wfs *pb.WorkflowActionStatus) error) error } // TinkDB implements the Database interface. diff --git a/db/mock/mock.go b/db/mock/mock.go index 57cfb14a9..aa4196fb5 100644 --- a/db/mock/mock.go +++ b/db/mock/mock.go @@ -19,7 +19,7 @@ type DB struct { InsertIntoWfDataTableFunc func(ctx context.Context, req *pb.UpdateWorkflowDataRequest) error GetWorkflowMetadataFunc func(ctx context.Context, req *pb.GetWorkflowDataRequest) ([]byte, error) GetWorkflowDataVersionFunc func(ctx context.Context, workflowID string) (int32, error) - GetWorkflowsForWorkerFunc func(id string) ([]string, error) + GetWorkflowsForWorkerFunc func(ctx context.Context, id string) ([]string, error) GetWorkflowContextsFunc func(ctx context.Context, wfID string) (*pb.WorkflowContext, error) GetWorkflowActionsFunc func(ctx context.Context, wfID string) (*pb.WorkflowActionList, error) UpdateWorkflowStateFunc func(ctx context.Context, wfContext *pb.WorkflowContext) error diff --git a/db/mock/workflow.go b/db/mock/workflow.go index 31b113f91..86004e010 100644 --- a/db/mock/workflow.go +++ b/db/mock/workflow.go @@ -35,8 +35,8 @@ func (d DB) GetWorkflowDataVersion(ctx context.Context, workflowID string) (int3 } // GetWorkflowsForWorker : returns the list of workflows for a particular worker. -func (d DB) GetWorkflowsForWorker(id string) ([]string, error) { - return d.GetWorkflowsForWorkerFunc(id) +func (d DB) GetWorkflowsForWorker(ctx context.Context, id string) ([]string, error) { + return d.GetWorkflowsForWorkerFunc(ctx, id) } // GetWorkflow returns a workflow. diff --git a/db/workflow.go b/db/workflow.go index cc8556560..42d926568 100644 --- a/db/workflow.go +++ b/db/workflow.go @@ -302,8 +302,8 @@ func (d TinkDB) GetWorkflowDataVersion(ctx context.Context, workflowID string) ( } // GetWorkflowsForWorker : returns the list of workflows for a particular worker. -func (d TinkDB) GetWorkflowsForWorker(id string) ([]string, error) { - rows, err := d.instance.Query(` +func (d TinkDB) GetWorkflowsForWorker(ctx context.Context, id string) ([]string, error) { + rows, err := d.instance.QueryContext(ctx, ` SELECT workflow_id FROM workflow_worker_map WHERE diff --git a/grpc-server/tinkerbell.go b/grpc-server/tinkerbell.go index 798010ad6..8d3476608 100644 --- a/grpc-server/tinkerbell.go +++ b/grpc-server/tinkerbell.go @@ -30,16 +30,16 @@ const ( // GetWorkflowContexts implements tinkerbell.GetWorkflowContexts. func (s *server) GetWorkflowContexts(req *pb.WorkflowContextRequest, stream pb.WorkflowService_GetWorkflowContextsServer) error { - wfs, err := getWorkflowsForWorker(s.db, req.WorkerId) + wfs, err := getWorkflowsForWorker(stream.Context(), s.db, req.WorkerId) if err != nil { return err } for _, wf := range wfs { - wfContext, err := s.db.GetWorkflowContexts(context.Background(), wf) + wfContext, err := s.db.GetWorkflowContexts(stream.Context(), wf) if err != nil { return status.Errorf(codes.Aborted, err.Error()) } - if isApplicableToSend(context.Background(), s.logger, wfContext, req.WorkerId, s.db) { + if isApplicableToSend(stream.Context(), s.logger, wfContext, req.WorkerId, s.db) { if err := stream.Send(wfContext); err != nil { return err } @@ -50,7 +50,7 @@ func (s *server) GetWorkflowContexts(req *pb.WorkflowContextRequest, stream pb.W // GetWorkflowContextList implements tinkerbell.GetWorkflowContextList. func (s *server) GetWorkflowContextList(ctx context.Context, req *pb.WorkflowContextRequest) (*pb.WorkflowContextList, error) { - wfs, err := getWorkflowsForWorker(s.db, req.WorkerId) + wfs, err := getWorkflowsForWorker(ctx, s.db, req.WorkerId) if err != nil { return nil, err } @@ -167,7 +167,7 @@ func (s *server) UpdateWorkflowData(ctx context.Context, req *pb.UpdateWorkflowD // GetWorkflowData gets the ephemeral data for a workflow. func (s *server) GetWorkflowData(ctx context.Context, req *pb.GetWorkflowDataRequest) (*pb.GetWorkflowDataResponse, error) { - if wfID := req.GetWorkflowId(); wfID == "" { + if id := req.GetWorkflowId(); id == "" { return &pb.GetWorkflowDataResponse{Data: []byte("")}, status.Errorf(codes.InvalidArgument, errInvalidWorkflowID) } @@ -196,11 +196,11 @@ func (s *server) GetWorkflowDataVersion(ctx context.Context, req *pb.GetWorkflow return &pb.GetWorkflowDataResponse{Version: version}, nil } -func getWorkflowsForWorker(d db.Database, id string) ([]string, error) { +func getWorkflowsForWorker(ctx context.Context, d db.Database, id string) ([]string, error) { if id == "" { return nil, status.Errorf(codes.InvalidArgument, errInvalidWorkerID) } - wfs, err := d.GetWorkflowsForWorker(id) + wfs, err := d.GetWorkflowsForWorker(ctx, id) if err != nil { return nil, status.Errorf(codes.Aborted, err.Error()) } diff --git a/grpc-server/tinkerbell_test.go b/grpc-server/tinkerbell_test.go index 6de6c8512..1c994042c 100644 --- a/grpc-server/tinkerbell_test.go +++ b/grpc-server/tinkerbell_test.go @@ -72,7 +72,7 @@ func TestGetWorkflowContextList(t *testing.T) { "database failure": { args: args{ db: &mock.DB{ - GetWorkflowsForWorkerFunc: func(id string) ([]string, error) { + GetWorkflowsForWorkerFunc: func(ctx context.Context, id string) ([]string, error) { return []string{workflowID}, nil }, GetWorkflowContextsFunc: func(ctx context.Context, wfID string) (*pb.WorkflowContext, error) { @@ -88,7 +88,7 @@ func TestGetWorkflowContextList(t *testing.T) { "no workflows found": { args: args{ db: &mock.DB{ - GetWorkflowsForWorkerFunc: func(id string) ([]string, error) { + GetWorkflowsForWorkerFunc: func(ctx context.Context, id string) ([]string, error) { return nil, nil }, GetWorkflowContextsFunc: func(ctx context.Context, wfID string) (*pb.WorkflowContext, error) { @@ -104,7 +104,7 @@ func TestGetWorkflowContextList(t *testing.T) { "workflows found": { args: args{ db: &mock.DB{ - GetWorkflowsForWorkerFunc: func(id string) ([]string, error) { + GetWorkflowsForWorkerFunc: func(ctx context.Context, id string) ([]string, error) { return []string{workflowID}, nil }, GetWorkflowContextsFunc: func(ctx context.Context, wfID string) (*pb.WorkflowContext, error) { @@ -763,7 +763,7 @@ func TestGetWorkflowsForWorker(t *testing.T) { "database failure": { args: args{ db: &mock.DB{ - GetWorkflowsForWorkerFunc: func(id string) ([]string, error) { + GetWorkflowsForWorkerFunc: func(ctx context.Context, id string) ([]string, error) { return nil, errors.New("database failed") }, }, @@ -776,7 +776,7 @@ func TestGetWorkflowsForWorker(t *testing.T) { "no workflows found": { args: args{ db: &mock.DB{ - GetWorkflowsForWorkerFunc: func(id string) ([]string, error) { + GetWorkflowsForWorkerFunc: func(ctx context.Context, id string) ([]string, error) { return nil, nil }, }, @@ -789,7 +789,7 @@ func TestGetWorkflowsForWorker(t *testing.T) { "workflows found": { args: args{ db: &mock.DB{ - GetWorkflowsForWorkerFunc: func(id string) ([]string, error) { + GetWorkflowsForWorkerFunc: func(ctx context.Context, id string) ([]string, error) { return []string{workflowID}, nil }, }, @@ -804,7 +804,7 @@ func TestGetWorkflowsForWorker(t *testing.T) { for name, tc := range testCases { t.Run(name, func(t *testing.T) { s := testServer(t, tc.args.db) - res, err := getWorkflowsForWorker(s.db, tc.args.workerID) + res, err := getWorkflowsForWorker(context.Background(), s.db, tc.args.workerID) if err != nil { assert.True(t, tc.want.expectedError) assert.Error(t, err)