From 0877a4f91ccd7e0f707dad7220dbf07eb02adbc1 Mon Sep 17 00:00:00 2001 From: bthari Date: Tue, 15 Oct 2024 13:28:29 +0700 Subject: [PATCH] feat: add webhook call to other events (#612) # Description Previously on #601 I added webhook for model version endpoint related event, and in here the event will be expanded into a model, model endpoint, model version related event, as we also want to have an action (from other service) to be triggered if these events happen. # Modifications - created another package for webhook interface - add event for: - model created - model endpoint created/updated/deleted - model version created/updated/deleted - change previous event name of `on-model-version-*` to `on-version-endpoint-*` # Tests # Checklist - [x] Added PR label - [x] Added unit test, integration, and/or e2e tests - [x] Tested locally - [ ] Updated documentation - [ ] Update Swagger spec if the PR introduce API changes - [ ] Regenerated Golang and Python client if the PR introduces API changes # Notes for Reviewer The version of MLP used here has a validation bug (which is updated on [MLP#117](https://github.com/caraml-dev/mlp/pull/117)). The code could still work with workaround (e.g. set `FinalResponse: true` in one async webhook if user use _all_ async webhook, but it will be confusing for user since async webhook response is expected to not be used anywhere), ~~so preferably to merge this PR after updating the MLP version as dependencies.~~ (MLP version will be updated with the s3 PR) # Release Notes ```release-note add webhook event call if there's changes on model, model endpoint, model version ``` --- api/api/model_endpoints_api.go | 17 +++ api/api/models_api.go | 7 + api/api/router.go | 3 +- api/api/versions_api.go | 21 ++- api/api/versions_api_test.go | 122 +++++++++++++++++ api/cmd/api/main.go | 10 +- api/cmd/api/setup.go | 10 +- api/queue/work/model_service_deployment.go | 19 +-- .../work/model_service_deployment_test.go | 125 ++++++++++------- api/service/version_endpoint_service.go | 39 ++---- api/service/version_endpoint_service_test.go | 54 +++----- api/webhook/mocks/webhook.go | 56 ++++++++ api/webhook/request.go | 10 ++ api/webhook/webhook.go | 128 ++++++++++++++++++ api/webhook/webhook_test.go | 102 ++++++++++++++ api/webhooks/webhooks.go | 24 ---- 16 files changed, 581 insertions(+), 166 deletions(-) create mode 100644 api/webhook/mocks/webhook.go create mode 100644 api/webhook/request.go create mode 100644 api/webhook/webhook.go create mode 100644 api/webhook/webhook_test.go delete mode 100644 api/webhooks/webhooks.go diff --git a/api/api/model_endpoints_api.go b/api/api/model_endpoints_api.go index e3a54375c..d55dea2b0 100644 --- a/api/api/model_endpoints_api.go +++ b/api/api/model_endpoints_api.go @@ -19,7 +19,9 @@ import ( "fmt" "net/http" + "github.com/caraml-dev/merlin/log" "github.com/caraml-dev/merlin/models" + "github.com/caraml-dev/merlin/webhook" "gorm.io/gorm" ) @@ -124,6 +126,11 @@ func (c *ModelEndpointsController) CreateModelEndpoint(r *http.Request, vars map return InternalServerError(fmt.Sprintf("Error creating model endpoint: %v", err)) } + // trigger webhook call + if err = c.Webhook.TriggerWebhooks(ctx, webhook.OnModelEndpointCreated, webhook.SetBody(endpoint)); err != nil { + log.Warnf("unable to invoke webhook for event type: %s, model: %s, endpoint: %d, error: %v", webhook.OnModelEndpointCreated, model.Name, endpoint.ID, err) + } + // Success. Return endpoint as response. return Created(endpoint) } @@ -192,6 +199,11 @@ func (c *ModelEndpointsController) UpdateModelEndpoint(r *http.Request, vars map return InternalServerError(fmt.Sprintf("Error updating model endpoint: %v", err)) } + // trigger webhook call + if err = c.Webhook.TriggerWebhooks(ctx, webhook.OnModelEndpointUpdated, webhook.SetBody(newEndpoint)); err != nil { + log.Warnf("unable to invoke webhook for event type: %s, model: %s, error: %v", webhook.OnModelEndpointUpdated, model.Name, err) + } + return Ok(newEndpoint) } @@ -227,5 +239,10 @@ func (c *ModelEndpointsController) DeleteModelEndpoint(r *http.Request, vars map return InternalServerError(fmt.Sprintf("Error deleting model endpoint: %v", err)) } + // trigger webhook call + if err = c.Webhook.TriggerWebhooks(ctx, webhook.OnModelEndpointDeleted, webhook.SetBody(modelEndpoint)); err != nil { + log.Warnf("unable to invoke webhook for event type: %s, model: %s, error: %v", webhook.OnModelEndpointDeleted, model.Name, err) + } + return Ok(nil) } diff --git a/api/api/models_api.go b/api/api/models_api.go index d58af773b..bdb999472 100644 --- a/api/api/models_api.go +++ b/api/api/models_api.go @@ -20,8 +20,10 @@ import ( "net/http" "strconv" + "github.com/caraml-dev/merlin/webhook" "gorm.io/gorm" + "github.com/caraml-dev/merlin/log" "github.com/caraml-dev/merlin/mlflow" "github.com/caraml-dev/merlin/models" "github.com/caraml-dev/merlin/service" @@ -78,6 +80,11 @@ func (c *ModelsController) CreateModel(r *http.Request, vars map[string]string, return InternalServerError(fmt.Sprintf("Error saving model: %v", err)) } + // trigger webhook call + if err = c.Webhook.TriggerWebhooks(ctx, webhook.OnModelCreated, webhook.SetBody(model)); err != nil { + log.Warnf("unable to invoke webhook for event type: %s, project: %d, model: %s, error: %v", webhook.OnModelCreated, model.ProjectID, model.Name, err) + } + return Created(model) } diff --git a/api/api/router.go b/api/api/router.go index eca6695f6..6fc8902d8 100644 --- a/api/api/router.go +++ b/api/api/router.go @@ -24,8 +24,8 @@ import ( "strings" "time" + webhook "github.com/caraml-dev/merlin/webhook" mlflowDelete "github.com/caraml-dev/mlp/api/pkg/client/mlflow" - "github.com/feast-dev/feast/sdk/go/protos/feast/core" "github.com/go-playground/validator/v10" "github.com/gorilla/mux" @@ -75,6 +75,7 @@ type AppContext struct { FeastCoreClient core.CoreServiceClient MlflowClient mlflow.Client + Webhook webhook.Client } // Handler handles the API requests and responses. diff --git a/api/api/versions_api.go b/api/api/versions_api.go index aec0fc455..14552f8a0 100644 --- a/api/api/versions_api.go +++ b/api/api/versions_api.go @@ -20,11 +20,12 @@ import ( "fmt" "net/http" - "gorm.io/gorm" - + "github.com/caraml-dev/merlin/log" "github.com/caraml-dev/merlin/models" "github.com/caraml-dev/merlin/service" "github.com/caraml-dev/merlin/utils" + "github.com/caraml-dev/merlin/webhook" + "gorm.io/gorm" ) const DEFAULT_PYTHON_VERSION = "3.8.*" @@ -79,6 +80,11 @@ func (c *VersionsController) PatchVersion(r *http.Request, vars map[string]strin return InternalServerError(fmt.Sprintf("Error patching model version: %v", err)) } + // trigger webhook call + if err = c.Webhook.TriggerWebhooks(ctx, webhook.OnModelVersionUpdated, webhook.SetBody(v)); err != nil { + log.Warnf("unable to invoke webhook for event type: %s, model: %s, version: %d, error: %v", webhook.OnModelVersionUpdated, v.ModelID, v.ID, err) + } + return Ok(patchedVersion) } @@ -148,6 +154,12 @@ func (c *VersionsController) CreateVersion(r *http.Request, vars map[string]stri if err != nil { return InternalServerError(fmt.Sprintf("Failed to save version: %v", err)) } + + // trigger webhook call + if err = c.Webhook.TriggerWebhooks(ctx, webhook.OnModelVersionCreated, webhook.SetBody(version)); err != nil { + log.Warnf("unable to invoke webhook for event type: %s, model: %s, version: %d, error: %v", webhook.OnModelVersionCreated, version.ModelID, version.ID, err) + } + return Created(version) } @@ -211,6 +223,11 @@ func (c *VersionsController) DeleteVersion(r *http.Request, vars map[string]stri return InternalServerError(fmt.Sprintf("Delete model version failed: %s", err.Error())) } + // trigger webhook call + if err = c.Webhook.TriggerWebhooks(ctx, webhook.OnModelVersionDeleted, webhook.SetBody(version)); err != nil { + log.Warnf("unable to invoke webhook for event type: %s, model: %s, version: %d, error: %v", webhook.OnModelVersionDeleted, version.ModelID, version.ID, err) + } + return Ok(versionID) } diff --git a/api/api/versions_api_test.go b/api/api/versions_api_test.go index d6edda955..2edaef88d 100644 --- a/api/api/versions_api_test.go +++ b/api/api/versions_api_test.go @@ -22,6 +22,8 @@ import ( "testing" "github.com/caraml-dev/merlin/service" + "github.com/caraml-dev/merlin/webhook" + webhookMock "github.com/caraml-dev/merlin/webhook/mocks" "github.com/google/uuid" "github.com/caraml-dev/merlin/config" @@ -290,6 +292,7 @@ func TestPatchVersion(t *testing.T) { requestBody interface{} vars map[string]string versionService func() *mocks.VersionsService + webhook func() *webhookMock.Client expected *Response }{ { @@ -356,6 +359,11 @@ func TestPatchVersion(t *testing.T) { }, nil) return svc }, + webhook: func() *webhookMock.Client { + w := webhookMock.NewClient(t) + w.On("TriggerWebhooks", mock.Anything, webhook.OnModelVersionUpdated, mock.Anything).Return(nil) + return w + }, expected: &Response{ code: http.StatusOK, data: &models.Version{ @@ -443,6 +451,11 @@ func TestPatchVersion(t *testing.T) { }, nil) return svc }, + webhook: func() *webhookMock.Client { + w := webhookMock.NewClient(t) + w.On("TriggerWebhooks", mock.Anything, webhook.OnModelVersionUpdated, mock.Anything).Return(nil) + return w + }, expected: &Response{ code: http.StatusOK, data: &models.Version{ @@ -526,6 +539,11 @@ func TestPatchVersion(t *testing.T) { }, nil) return svc }, + webhook: func() *webhookMock.Client { + w := webhookMock.NewClient(t) + w.On("TriggerWebhooks", mock.Anything, webhook.OnModelVersionUpdated, mock.Anything).Return(nil) + return w + }, expected: &Response{ code: http.StatusOK, data: &models.Version{ @@ -576,6 +594,9 @@ func TestPatchVersion(t *testing.T) { }, nil) return svc }, + webhook: func() *webhookMock.Client { + return webhookMock.NewClient(t) + }, expected: &Response{ code: http.StatusBadRequest, data: Error{Message: "Error validating version: custom predictor image must be set"}, @@ -597,6 +618,9 @@ func TestPatchVersion(t *testing.T) { nil, gorm.ErrRecordNotFound) return svc }, + webhook: func() *webhookMock.Client { + return webhookMock.NewClient(t) + }, expected: &Response{ code: http.StatusNotFound, data: Error{Message: "Model version not found: record not found"}, @@ -618,6 +642,9 @@ func TestPatchVersion(t *testing.T) { nil, fmt.Errorf("Error creating secret: db is down")) return svc }, + webhook: func() *webhookMock.Client { + return webhookMock.NewClient(t) + }, expected: &Response{ code: http.StatusInternalServerError, data: Error{Message: "Error getting model version: Error creating secret: db is down"}, @@ -651,6 +678,9 @@ func TestPatchVersion(t *testing.T) { }, nil) return svc }, + webhook: func() *webhookMock.Client { + return webhookMock.NewClient(t) + }, expected: &Response{ code: http.StatusInternalServerError, data: Error{Message: "Unable to parse request body"}, @@ -703,6 +733,9 @@ func TestPatchVersion(t *testing.T) { }, mock.Anything).Return(nil, fmt.Errorf("Error creating secret: db is down")) return svc }, + webhook: func() *webhookMock.Client { + return webhookMock.NewClient(t) + }, expected: &Response{ code: http.StatusInternalServerError, data: Error{Message: "Error patching model version: Error creating secret: db is down"}, @@ -831,6 +864,11 @@ func TestPatchVersion(t *testing.T) { }, nil) return svc }, + webhook: func() *webhookMock.Client { + w := webhookMock.NewClient(t) + w.On("TriggerWebhooks", mock.Anything, webhook.OnModelVersionUpdated, mock.Anything).Return(nil) + return w + }, expected: &Response{ code: http.StatusOK, data: &models.Version{ @@ -922,6 +960,9 @@ func TestPatchVersion(t *testing.T) { }, nil) return svc }, + webhook: func() *webhookMock.Client { + return webhookMock.NewClient(t) + }, expected: &Response{ code: http.StatusBadRequest, data: Error{ @@ -933,6 +974,7 @@ func TestPatchVersion(t *testing.T) { for _, tC := range testCases { t.Run(tC.desc, func(t *testing.T) { versionSvc := tC.versionService() + webhook := tC.webhook() ctl := &VersionsController{ AppContext: &AppContext{ @@ -946,6 +988,7 @@ func TestPatchVersion(t *testing.T) { MonitoringBaseURL: "http://grafana", }, }, + Webhook: webhook, }, } resp := ctl.PatchVersion(&http.Request{}, tC.vars, tC.requestBody) @@ -962,6 +1005,7 @@ func TestCreateVersion(t *testing.T) { versionService func() *mocks.VersionsService mlflowClient func() *mlfmocks.Client modelsService func() *mocks.ModelsService + webhook func() *webhookMock.Client expected *Response }{ { @@ -1037,6 +1081,11 @@ func TestCreateVersion(t *testing.T) { }, nil) return svc }, + webhook: func() *webhookMock.Client { + w := webhookMock.NewClient(t) + w.On("TriggerWebhooks", mock.Anything, webhook.OnModelVersionCreated, mock.Anything).Return(nil) + return w + }, expected: &Response{ code: http.StatusCreated, data: &models.Version{ @@ -1115,6 +1164,9 @@ func TestCreateVersion(t *testing.T) { }, mock.Anything).Return(nil, fmt.Errorf("pq constraint violation")) return svc }, + webhook: func() *webhookMock.Client { + return webhookMock.NewClient(t) + }, expected: &Response{ code: http.StatusInternalServerError, data: Error{ @@ -1147,6 +1199,9 @@ func TestCreateVersion(t *testing.T) { svc.On("Save", mock.Anything, &models.Version{}, mock.Anything).Return(&models.Version{}, nil) return svc }, + webhook: func() *webhookMock.Client { + return webhookMock.NewClient(t) + }, expected: &Response{ code: http.StatusBadRequest, data: Error{Message: "Valid label key/values must be 63 characters or less and must be empty or begin and end with an alphanumeric character ([a-z0-9A-Z]) with dashes (-), underscores (_), dots (.), and alphanumerics between."}, @@ -1177,6 +1232,9 @@ func TestCreateVersion(t *testing.T) { svc.On("Save", mock.Anything, &models.Version{}, mock.Anything).Return(&models.Version{}, nil) return svc }, + webhook: func() *webhookMock.Client { + return webhookMock.NewClient(t) + }, expected: &Response{ code: http.StatusBadRequest, data: Error{Message: "Valid label key/values must be 63 characters or less and must be empty or begin and end with an alphanumeric character ([a-z0-9A-Z]) with dashes (-), underscores (_), dots (.), and alphanumerics between."}, @@ -1207,6 +1265,9 @@ func TestCreateVersion(t *testing.T) { svc.On("Save", mock.Anything, &models.Version{}, mock.Anything).Return(&models.Version{}, nil) return svc }, + webhook: func() *webhookMock.Client { + return webhookMock.NewClient(t) + }, expected: &Response{ code: http.StatusBadRequest, data: Error{Message: "Valid label key/values must be 63 characters or less and must be empty or begin and end with an alphanumeric character ([a-z0-9A-Z]) with dashes (-), underscores (_), dots (.), and alphanumerics between."}, @@ -1237,6 +1298,9 @@ func TestCreateVersion(t *testing.T) { svc.On("Save", mock.Anything, &models.Version{}, mock.Anything).Return(&models.Version{}, nil) return svc }, + webhook: func() *webhookMock.Client { + return webhookMock.NewClient(t) + }, expected: &Response{ code: http.StatusBadRequest, data: Error{Message: "Valid label key/values must be 63 characters or less and must be empty or begin and end with an alphanumeric character ([a-z0-9A-Z]) with dashes (-), underscores (_), dots (.), and alphanumerics between."}, @@ -1267,6 +1331,9 @@ func TestCreateVersion(t *testing.T) { svc.On("Save", mock.Anything, &models.Version{}, mock.Anything).Return(&models.Version{}, nil) return svc }, + webhook: func() *webhookMock.Client { + return webhookMock.NewClient(t) + }, expected: &Response{ code: http.StatusBadRequest, data: Error{Message: "Valid label key/values must be 63 characters or less and must be empty or begin and end with an alphanumeric character ([a-z0-9A-Z]) with dashes (-), underscores (_), dots (.), and alphanumerics between."}, @@ -1297,6 +1364,9 @@ func TestCreateVersion(t *testing.T) { svc.On("Save", mock.Anything, &models.Version{}, mock.Anything).Return(&models.Version{}, nil) return svc }, + webhook: func() *webhookMock.Client { + return webhookMock.NewClient(t) + }, expected: &Response{ code: http.StatusBadRequest, data: Error{Message: "Valid label key/values must be 63 characters or less and must be empty or begin and end with an alphanumeric character ([a-z0-9A-Z]) with dashes (-), underscores (_), dots (.), and alphanumerics between."}, @@ -1327,6 +1397,9 @@ func TestCreateVersion(t *testing.T) { svc.On("Save", mock.Anything, &models.Version{}, mock.Anything).Return(&models.Version{}, nil) return svc }, + webhook: func() *webhookMock.Client { + return webhookMock.NewClient(t) + }, expected: &Response{ code: http.StatusBadRequest, data: Error{Message: "Valid label key/values must be 63 characters or less and must be empty or begin and end with an alphanumeric character ([a-z0-9A-Z]) with dashes (-), underscores (_), dots (.), and alphanumerics between."}, @@ -1357,6 +1430,9 @@ func TestCreateVersion(t *testing.T) { svc.On("Save", mock.Anything, &models.Version{}, mock.Anything).Return(&models.Version{}, nil) return svc }, + webhook: func() *webhookMock.Client { + return webhookMock.NewClient(t) + }, expected: &Response{ code: http.StatusBadRequest, data: Error{Message: "Valid label key/values must be 63 characters or less and must be empty or begin and end with an alphanumeric character ([a-z0-9A-Z]) with dashes (-), underscores (_), dots (.), and alphanumerics between."}, @@ -1418,6 +1494,11 @@ func TestCreateVersion(t *testing.T) { }, nil) return svc }, + webhook: func() *webhookMock.Client { + w := webhookMock.NewClient(t) + w.On("TriggerWebhooks", mock.Anything, webhook.OnModelVersionCreated, mock.Anything).Return(nil) + return w + }, expected: &Response{ code: http.StatusCreated, data: &models.Version{ @@ -1551,6 +1632,11 @@ func TestCreateVersion(t *testing.T) { }, nil) return svc }, + webhook: func() *webhookMock.Client { + w := webhookMock.NewClient(t) + w.On("TriggerWebhooks", mock.Anything, webhook.OnModelVersionCreated, mock.Anything).Return(nil) + return w + }, expected: &Response{ code: http.StatusCreated, data: &models.Version{ @@ -1645,6 +1731,9 @@ func TestCreateVersion(t *testing.T) { svc := &mocks.VersionsService{} return svc }, + webhook: func() *webhookMock.Client { + return webhookMock.NewClient(t) + }, expected: &Response{ code: http.StatusBadRequest, data: Error{ @@ -1658,6 +1747,7 @@ func TestCreateVersion(t *testing.T) { versionSvc := tC.versionService() modelsSvc := tC.modelsService() mlflowClient := tC.mlflowClient() + webhook := tC.webhook() ctl := &VersionsController{ AppContext: &AppContext{ @@ -1673,6 +1763,7 @@ func TestCreateVersion(t *testing.T) { }, MlflowClient: mlflowClient, ModelsService: modelsSvc, + Webhook: webhook, }, } resp := ctl.CreateVersion(&http.Request{}, tC.vars, &tC.body) @@ -1690,6 +1781,7 @@ func TestDeleteVersion(t *testing.T) { mlflowDeleteService func() *mlflowDeleteServiceMocks.Service predictionJobService func() *mocks.PredictionJobService endpointService func() *mocks.EndpointsService + webhook func() *webhookMock.Client expected *Response }{ { @@ -1752,6 +1844,11 @@ func TestDeleteVersion(t *testing.T) { svc.On("ListEndpoints", mock.Anything, mock.Anything, mock.Anything).Return([]*models.VersionEndpoint{}, nil) return svc }, + webhook: func() *webhookMock.Client { + w := webhookMock.NewClient(t) + w.On("TriggerWebhooks", mock.Anything, webhook.OnModelVersionDeleted, mock.Anything).Return(nil) + return w + }, expected: &Response{ code: http.StatusOK, data: models.ID(1), @@ -1813,6 +1910,11 @@ func TestDeleteVersion(t *testing.T) { svc.On("DeleteRun", mock.Anything, "runID1", mock.Anything, true).Return(nil) return svc }, + webhook: func() *webhookMock.Client { + w := webhookMock.NewClient(t) + w.On("TriggerWebhooks", mock.Anything, webhook.OnModelVersionDeleted, mock.Anything).Return(nil) + return w + }, expected: &Response{ code: http.StatusOK, data: models.ID(1), @@ -1887,6 +1989,9 @@ func TestDeleteVersion(t *testing.T) { svc.On("DeleteRun", mock.Anything, "runID1", mock.Anything, true).Return(nil) return svc }, + webhook: func() *webhookMock.Client { + return webhookMock.NewClient(t) + }, expected: &Response{ code: http.StatusBadRequest, data: Error{Message: "There are active endpoint that still using this model version"}, @@ -1963,6 +2068,9 @@ func TestDeleteVersion(t *testing.T) { svc.On("DeleteRun", mock.Anything, "runID1", mock.Anything, true).Return(nil) return svc }, + webhook: func() *webhookMock.Client { + return webhookMock.NewClient(t) + }, expected: &Response{ code: http.StatusInternalServerError, data: Error{Message: "Failed to delete endpoint: failed to delete endpoint"}, @@ -2037,6 +2145,9 @@ func TestDeleteVersion(t *testing.T) { svc := &mocks.EndpointsService{} return svc }, + webhook: func() *webhookMock.Client { + return webhookMock.NewClient(t) + }, expected: &Response{ code: http.StatusBadRequest, data: Error{Message: "There are active prediction job that still using this model version"}, @@ -2113,6 +2224,9 @@ func TestDeleteVersion(t *testing.T) { svc := &mocks.EndpointsService{} return svc }, + webhook: func() *webhookMock.Client { + return webhookMock.NewClient(t) + }, expected: &Response{ code: http.StatusInternalServerError, data: Error{Message: "Failed stopping prediction job: failed to stop prediction job"}, @@ -2178,6 +2292,9 @@ func TestDeleteVersion(t *testing.T) { svc.On("ListEndpoints", mock.Anything, mock.Anything, mock.Anything).Return([]*models.VersionEndpoint{}, nil) return svc }, + webhook: func() *webhookMock.Client { + return webhookMock.NewClient(t) + }, expected: &Response{ code: http.StatusInternalServerError, data: Error{Message: "Delete mlflow run failed: failed to delete mlflow run"}, @@ -2243,6 +2360,9 @@ func TestDeleteVersion(t *testing.T) { svc.On("ListEndpoints", mock.Anything, mock.Anything, mock.Anything).Return([]*models.VersionEndpoint{}, nil) return svc }, + webhook: func() *webhookMock.Client { + return webhookMock.NewClient(t) + }, expected: &Response{ code: http.StatusInternalServerError, data: Error{Message: "Delete model version failed: failed to delete model version"}, @@ -2256,6 +2376,7 @@ func TestDeleteVersion(t *testing.T) { mlflowDeleteSvc := tC.mlflowDeleteService predictionJobSvc := tC.predictionJobService endpointService := tC.endpointService + webhook := tC.webhook() ctl := &VersionsController{ AppContext: &AppContext{ @@ -2273,6 +2394,7 @@ func TestDeleteVersion(t *testing.T) { MlflowDeleteService: mlflowDeleteSvc(), PredictionJobService: predictionJobSvc(), EndpointsService: endpointService(), + Webhook: webhook, }, } resp := ctl.DeleteVersion(&http.Request{}, tC.vars, nil) diff --git a/api/cmd/api/main.go b/api/cmd/api/main.go index 47a9df818..124320cbd 100644 --- a/api/cmd/api/main.go +++ b/api/cmd/api/main.go @@ -29,7 +29,6 @@ import ( "github.com/caraml-dev/merlin/cluster/labeller" mlflowDelete "github.com/caraml-dev/mlp/api/pkg/client/mlflow" "github.com/caraml-dev/mlp/api/pkg/instrumentation/sentry" - webhookManager "github.com/caraml-dev/mlp/api/pkg/webhooks" _ "github.com/golang-migrate/migrate/v4/source/file" "github.com/gorilla/mux" "github.com/heptiolabs/healthcheck" @@ -49,7 +48,7 @@ import ( "github.com/caraml-dev/merlin/service" "github.com/caraml-dev/merlin/storage" "github.com/caraml-dev/merlin/warden" - "github.com/caraml-dev/merlin/webhooks" + "github.com/caraml-dev/merlin/webhook" "github.com/caraml-dev/mlp/api/pkg/authz/enforcer" "github.com/caraml-dev/mlp/api/pkg/instrumentation/newrelic" ) @@ -268,11 +267,7 @@ func buildDependencies(ctx context.Context, cfg *config.Config, db *gorm.DB, dis log.Panicf("invalid deployment label prefix (%s): %s", cfg.DeploymentLabelPrefix, err) } - webhookClient, err := webhookManager.InitializeWebhooks(&cfg.WebhooksConfig, webhooks.WebhookEvents) - if err != nil { - log.Panicf("failed to initialize webhooks: %s", err) - } - + webhookClient := webhook.NewWebhook(&cfg.WebhooksConfig) webServiceBuilder, predJobBuilder, imageBuilderJanitor := initImageBuilder(cfg) observabilityPublisherStorage := storage.NewObservabilityPublisherStorage(db) @@ -367,6 +362,7 @@ func buildDependencies(ctx context.Context, cfg *config.Config, db *gorm.DB, dis FeastCoreClient: coreClient, MlflowClient: mlflowClient, + Webhook: webhookClient, } return deps{ apiContext: apiContext, diff --git a/api/cmd/api/setup.go b/api/cmd/api/setup.go index 180885343..26c5b501b 100644 --- a/api/cmd/api/setup.go +++ b/api/cmd/api/setup.go @@ -8,9 +8,9 @@ import ( gcs "cloud.google.com/go/storage" "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/client/clientset/versioned" + "github.com/caraml-dev/merlin/webhook" "github.com/caraml-dev/mlp/api/pkg/artifact" "github.com/caraml-dev/mlp/api/pkg/auth" - "github.com/caraml-dev/mlp/api/pkg/webhooks" feast "github.com/feast-dev/feast/sdk/go" "github.com/feast-dev/feast/sdk/go/protos/feast/core" "google.golang.org/grpc" @@ -422,7 +422,7 @@ func initPredictionJobService(cfg *config.Config, controllers map[string]batch.C return service.NewPredictionJobService(controllers, builder, predictionJobStorage, clock.RealClock{}, cfg.Environment, producer) } -func initModelServiceDeployment(cfg *config.Config, builder imagebuilder.ImageBuilder, controllers map[string]cluster.Controller, db *gorm.DB, observabilityEvent event.EventProducer, webhookManager webhooks.WebhookManager) *work.ModelServiceDeployment { +func initModelServiceDeployment(cfg *config.Config, builder imagebuilder.ImageBuilder, controllers map[string]cluster.Controller, db *gorm.DB, observabilityEvent event.EventProducer, webhook webhook.Client) *work.ModelServiceDeployment { return &work.ModelServiceDeployment{ ClusterControllers: controllers, ImageBuilder: builder, @@ -431,7 +431,7 @@ func initModelServiceDeployment(cfg *config.Config, builder imagebuilder.ImageBu LoggerDestinationURL: cfg.LoggerDestinationURL, MLObsLoggerDestinationURL: cfg.MLObsLoggerDestinationURL, ObservabilityEventProducer: observabilityEvent, - WebhookManager: webhookManager, + Webhook: webhook, } } @@ -504,7 +504,7 @@ func initClusterControllers(cfg *config.Config) map[string]cluster.Controller { return controllers } -func initVersionEndpointService(cfg *config.Config, builder imagebuilder.ImageBuilder, controllers map[string]cluster.Controller, db *gorm.DB, feastCoreClient core.CoreServiceClient, producer queue.Producer, webhookManager webhooks.WebhookManager) service.EndpointsService { +func initVersionEndpointService(cfg *config.Config, builder imagebuilder.ImageBuilder, controllers map[string]cluster.Controller, db *gorm.DB, feastCoreClient core.CoreServiceClient, producer queue.Producer, webhook webhook.Client) service.EndpointsService { return service.NewEndpointService(service.EndpointServiceParams{ ClusterControllers: controllers, ImageBuilder: builder, @@ -516,7 +516,7 @@ func initVersionEndpointService(cfg *config.Config, builder imagebuilder.ImageBu JobProducer: producer, FeastCoreClient: feastCoreClient, StandardTransformerConfig: cfg.StandardTransformerConfig, - WebhookManager: webhookManager, + Webhook: webhook, }) } diff --git a/api/queue/work/model_service_deployment.go b/api/queue/work/model_service_deployment.go index a34b4776c..b169789a3 100644 --- a/api/queue/work/model_service_deployment.go +++ b/api/queue/work/model_service_deployment.go @@ -15,8 +15,7 @@ import ( "github.com/caraml-dev/merlin/pkg/observability/event" "github.com/caraml-dev/merlin/queue" "github.com/caraml-dev/merlin/storage" - "github.com/caraml-dev/merlin/webhooks" - webhookManager "github.com/caraml-dev/mlp/api/pkg/webhooks" + "github.com/caraml-dev/merlin/webhook" "github.com/prometheus/client_golang/prometheus" "gorm.io/gorm" ) @@ -44,7 +43,7 @@ type ModelServiceDeployment struct { LoggerDestinationURL string MLObsLoggerDestinationURL string ObservabilityEventProducer event.EventProducer - WebhookManager webhookManager.WebhookManager + Webhook webhook.Client } type EndpointJob struct { @@ -211,17 +210,9 @@ func (depl *ModelServiceDeployment) Deploy(job *queue.Job) error { } } - // calling webhooks if there's any webhooks configured - if depl.WebhookManager != nil && depl.WebhookManager.IsEventConfigured(webhooks.OnModelVersionDeployed) { - body := &webhooks.VersionEndpointRequest{ - EventType: webhooks.OnModelVersionDeployed, - VersionEndpoint: endpoint, - } - - err = depl.WebhookManager.InvokeWebhooks(ctx, webhooks.OnModelVersionDeployed, body, webhookManager.NoOpCallback, webhookManager.NoOpErrorHandler) - if err != nil { - log.Warnf("unable to invoke webhooks for event type: %s, model: %s, version: %s, error: %v", webhooks.OnModelVersionDeployed, model.Name, version.ID, err) - } + // trigger webhook call + if err = depl.Webhook.TriggerWebhooks(ctx, webhook.OnVersionEndpointDeployed, webhook.SetBody(endpoint)); err != nil { + log.Warnf("unable to invoke webhook for event type: %s, model: %s, endpoint: %d, error: %v", webhook.OnVersionEndpointDeployed, endpoint.VersionModelID, endpoint.ID, err) } return nil diff --git a/api/queue/work/model_service_deployment_test.go b/api/queue/work/model_service_deployment_test.go index f3e95895b..e1201fb19 100644 --- a/api/queue/work/model_service_deployment_test.go +++ b/api/queue/work/model_service_deployment_test.go @@ -14,8 +14,8 @@ import ( eventMock "github.com/caraml-dev/merlin/pkg/observability/event/mocks" "github.com/caraml-dev/merlin/queue" "github.com/caraml-dev/merlin/storage/mocks" - webhook "github.com/caraml-dev/merlin/webhooks" - webhookManager "github.com/caraml-dev/mlp/api/pkg/webhooks" + "github.com/caraml-dev/merlin/webhook" + webhookMock "github.com/caraml-dev/merlin/webhook/mocks" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "gorm.io/gorm" @@ -74,7 +74,7 @@ func TestExecuteDeployment(t *testing.T) { storage func() *mocks.VersionEndpointStorage controller func() *clusterMock.Controller imageBuilder func() *imageBuilderMock.ImageBuilder - webhookManager func() webhookManager.WebhookManager + webhook func() *webhookMock.Client eventProducer *eventMock.EventProducer }{ { @@ -120,10 +120,10 @@ func TestExecuteDeployment(t *testing.T) { mockImgBuilder := &imageBuilderMock.ImageBuilder{} return mockImgBuilder }, - webhookManager: func() webhookManager.WebhookManager { - webhookManager := webhookManager.NewMockWebhookManager(t) - webhookManager.On("IsEventConfigured", mock.Anything).Return(false) - return webhookManager + webhook: func() *webhookMock.Client { + w := webhookMock.NewClient(t) + w.On("TriggerWebhooks", mock.Anything, webhook.OnVersionEndpointDeployed, mock.Anything).Return(nil) + return w }, }, { @@ -170,11 +170,12 @@ func TestExecuteDeployment(t *testing.T) { mockImgBuilder := &imageBuilderMock.ImageBuilder{} return mockImgBuilder }, - webhookManager: func() webhookManager.WebhookManager { - webhookManager := webhookManager.NewMockWebhookManager(t) - webhookManager.On("IsEventConfigured", mock.Anything).Return(false) - return webhookManager + webhook: func() *webhookMock.Client { + w := webhookMock.NewClient(t) + w.On("TriggerWebhooks", mock.Anything, webhook.OnVersionEndpointDeployed, mock.Anything).Return(nil) + return w }, + eventProducer: func() *eventMock.EventProducer { producer := &eventMock.EventProducer{} producer.On("VersionEndpointChangeEvent", &models.VersionEndpoint{ @@ -192,7 +193,7 @@ func TestExecuteDeployment(t *testing.T) { }(), }, { - name: "Success: with calling webhooks", + name: "Success: with calling webhook", model: model, version: version, endpoint: &models.VersionEndpoint{ @@ -234,12 +235,10 @@ func TestExecuteDeployment(t *testing.T) { mockImgBuilder := &imageBuilderMock.ImageBuilder{} return mockImgBuilder }, - webhookManager: func() webhookManager.WebhookManager { - manager := webhookManager.NewMockWebhookManager(t) - - manager.On("IsEventConfigured", mock.Anything).Return(true) - manager.On("InvokeWebhooks", mock.Anything, webhook.OnModelVersionDeployed, mock.IsType(&webhook.VersionEndpointRequest{}), mock.Anything, mock.Anything).Return(nil) - return manager + webhook: func() *webhookMock.Client { + w := webhookMock.NewClient(t) + w.On("TriggerWebhooks", mock.Anything, webhook.OnVersionEndpointDeployed, mock.Anything).Return(nil) + return w }, }, { @@ -286,11 +285,12 @@ func TestExecuteDeployment(t *testing.T) { mockImgBuilder := &imageBuilderMock.ImageBuilder{} return mockImgBuilder }, - webhookManager: func() webhookManager.WebhookManager { - webhookManager := webhookManager.NewMockWebhookManager(t) - webhookManager.On("IsEventConfigured", mock.Anything).Return(false) - return webhookManager + webhook: func() *webhookMock.Client { + w := webhookMock.NewClient(t) + w.On("TriggerWebhooks", mock.Anything, webhook.OnVersionEndpointDeployed, mock.Anything).Return(nil) + return w }, + eventProducer: func() *eventMock.EventProducer { producer := &eventMock.EventProducer{} producer.On("VersionEndpointChangeEvent", &models.VersionEndpoint{ @@ -358,10 +358,10 @@ func TestExecuteDeployment(t *testing.T) { mockImgBuilder := &imageBuilderMock.ImageBuilder{} return mockImgBuilder }, - webhookManager: func() webhookManager.WebhookManager { - webhookManager := webhookManager.NewMockWebhookManager(t) - webhookManager.On("IsEventConfigured", mock.Anything).Return(false) - return webhookManager + webhook: func() *webhookMock.Client { + w := webhookMock.NewClient(t) + w.On("TriggerWebhooks", mock.Anything, webhook.OnVersionEndpointDeployed, mock.Anything).Return(nil) + return w }, }, { @@ -415,10 +415,10 @@ func TestExecuteDeployment(t *testing.T) { mockImgBuilder := &imageBuilderMock.ImageBuilder{} return mockImgBuilder }, - webhookManager: func() webhookManager.WebhookManager { - webhookManager := webhookManager.NewMockWebhookManager(t) - webhookManager.On("IsEventConfigured", mock.Anything).Return(false) - return webhookManager + webhook: func() *webhookMock.Client { + w := webhookMock.NewClient(t) + w.On("TriggerWebhooks", mock.Anything, webhook.OnVersionEndpointDeployed, mock.Anything).Return(nil) + return w }, }, { @@ -464,10 +464,10 @@ func TestExecuteDeployment(t *testing.T) { mockImgBuilder := &imageBuilderMock.ImageBuilder{} return mockImgBuilder }, - webhookManager: func() webhookManager.WebhookManager { - webhookManager := webhookManager.NewMockWebhookManager(t) - webhookManager.On("IsEventConfigured", mock.Anything).Return(false) - return webhookManager + webhook: func() *webhookMock.Client { + w := webhookMock.NewClient(t) + w.On("TriggerWebhooks", mock.Anything, webhook.OnVersionEndpointDeployed, mock.Anything).Return(nil) + return w }, }, { @@ -515,10 +515,10 @@ func TestExecuteDeployment(t *testing.T) { Return("gojek/mymodel-1:latest", nil) return mockImgBuilder }, - webhookManager: func() webhookManager.WebhookManager { - webhookManager := webhookManager.NewMockWebhookManager(t) - webhookManager.On("IsEventConfigured", mock.Anything).Return(false) - return webhookManager + webhook: func() *webhookMock.Client { + w := webhookMock.NewClient(t) + w.On("TriggerWebhooks", mock.Anything, webhook.OnVersionEndpointDeployed, mock.Anything).Return(nil) + return w }, }, { @@ -566,10 +566,10 @@ func TestExecuteDeployment(t *testing.T) { Return("gojek/mymodel-1:latest", nil) return mockImgBuilder }, - webhookManager: func() webhookManager.WebhookManager { - webhookManager := webhookManager.NewMockWebhookManager(t) - webhookManager.On("IsEventConfigured", mock.Anything).Return(false) - return webhookManager + webhook: func() *webhookMock.Client { + w := webhookMock.NewClient(t) + w.On("TriggerWebhooks", mock.Anything, webhook.OnVersionEndpointDeployed, mock.Anything).Return(nil) + return w }, }, { @@ -629,10 +629,10 @@ func TestExecuteDeployment(t *testing.T) { mockImgBuilder := &imageBuilderMock.ImageBuilder{} return mockImgBuilder }, - webhookManager: func() webhookManager.WebhookManager { - webhookManager := webhookManager.NewMockWebhookManager(t) - webhookManager.On("IsEventConfigured", mock.Anything).Return(false) - return webhookManager + webhook: func() *webhookMock.Client { + w := webhookMock.NewClient(t) + w.On("TriggerWebhooks", mock.Anything, webhook.OnVersionEndpointDeployed, mock.Anything).Return(nil) + return w }, }, { @@ -672,8 +672,8 @@ func TestExecuteDeployment(t *testing.T) { mockImgBuilder := &imageBuilderMock.ImageBuilder{} return mockImgBuilder }, - webhookManager: func() webhookManager.WebhookManager { - return webhookManager.NewMockWebhookManager(t) + webhook: func() *webhookMock.Client { + return webhookMock.NewClient(t) }, }, { @@ -712,8 +712,8 @@ func TestExecuteDeployment(t *testing.T) { mockImgBuilder.On("BuildImage", context.Background(), mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return("", errors.New("Failed to build image")) return mockImgBuilder }, - webhookManager: func() webhookManager.WebhookManager { - return webhookManager.NewMockWebhookManager(t) + webhook: func() *webhookMock.Client { + return webhookMock.NewClient(t) }, }, } @@ -724,7 +724,7 @@ func TestExecuteDeployment(t *testing.T) { imgBuilder := tt.imageBuilder() mockStorage := tt.storage() mockDeploymentStorage := tt.deploymentStorage() - mockWebhook := tt.webhookManager() + mockWebhook := tt.webhook() job := &queue.Job{ Name: "job", Arguments: queue.Arguments{ @@ -743,7 +743,7 @@ func TestExecuteDeployment(t *testing.T) { DeploymentStorage: mockDeploymentStorage, LoggerDestinationURL: loggerDestinationURL, ObservabilityEventProducer: tt.eventProducer, - WebhookManager: mockWebhook, + Webhook: mockWebhook, } err := svc.Deploy(job) @@ -842,6 +842,7 @@ func TestExecuteRedeployment(t *testing.T) { storage func() *mocks.VersionEndpointStorage controller func() *clusterMock.Controller imageBuilder func() *imageBuilderMock.ImageBuilder + webhook func() *webhookMock.Client }{ { name: "Success: Redeploy running endpoint", @@ -914,6 +915,11 @@ func TestExecuteRedeployment(t *testing.T) { mockImgBuilder := &imageBuilderMock.ImageBuilder{} return mockImgBuilder }, + webhook: func() *webhookMock.Client { + w := webhookMock.NewClient(t) + w.On("TriggerWebhooks", mock.Anything, webhook.OnVersionEndpointDeployed, mock.Anything).Return(nil) + return w + }, }, { name: "Success: Redeploy serving endpoint", @@ -986,6 +992,11 @@ func TestExecuteRedeployment(t *testing.T) { mockImgBuilder := &imageBuilderMock.ImageBuilder{} return mockImgBuilder }, + webhook: func() *webhookMock.Client { + w := webhookMock.NewClient(t) + w.On("TriggerWebhooks", mock.Anything, webhook.OnVersionEndpointDeployed, mock.Anything).Return(nil) + return w + }, }, { name: "Success: Redeploy failed endpoint", @@ -1058,6 +1069,11 @@ func TestExecuteRedeployment(t *testing.T) { mockImgBuilder := &imageBuilderMock.ImageBuilder{} return mockImgBuilder }, + webhook: func() *webhookMock.Client { + w := webhookMock.NewClient(t) + w.On("TriggerWebhooks", mock.Anything, webhook.OnVersionEndpointDeployed, mock.Anything).Return(nil) + return w + }, }, { name: "Failed to redeploy running endpoint", @@ -1120,6 +1136,9 @@ func TestExecuteRedeployment(t *testing.T) { mockImgBuilder := &imageBuilderMock.ImageBuilder{} return mockImgBuilder }, + webhook: func() *webhookMock.Client { + return webhookMock.NewClient(t) + }, }, } for _, tt := range tests { @@ -1129,6 +1148,7 @@ func TestExecuteRedeployment(t *testing.T) { imgBuilder := tt.imageBuilder() mockStorage := tt.storage() mockDeploymentStorage := tt.deploymentStorage() + mockWebhook := tt.webhook() job := &queue.Job{ Name: "job", Arguments: queue.Arguments{ @@ -1146,6 +1166,7 @@ func TestExecuteRedeployment(t *testing.T) { Storage: mockStorage, DeploymentStorage: mockDeploymentStorage, LoggerDestinationURL: loggerDestinationURL, + Webhook: mockWebhook, } err := svc.Deploy(job) diff --git a/api/service/version_endpoint_service.go b/api/service/version_endpoint_service.go index bfd712232..eee58a608 100644 --- a/api/service/version_endpoint_service.go +++ b/api/service/version_endpoint_service.go @@ -33,8 +33,7 @@ import ( "github.com/caraml-dev/merlin/queue" "github.com/caraml-dev/merlin/queue/work" "github.com/caraml-dev/merlin/storage" - "github.com/caraml-dev/merlin/webhooks" - webhookManager "github.com/caraml-dev/mlp/api/pkg/webhooks" + "github.com/caraml-dev/merlin/webhook" "github.com/feast-dev/feast/sdk/go/protos/feast/core" "github.com/google/uuid" "google.golang.org/protobuf/encoding/protojson" @@ -69,7 +68,7 @@ type EndpointServiceParams struct { JobProducer queue.Producer FeastCoreClient core.CoreServiceClient StandardTransformerConfig config.StandardTransformerConfig - WebhookManager webhookManager.WebhookManager + Webhook webhook.Client } type endpointService struct { @@ -84,7 +83,7 @@ type endpointService struct { jobProducer queue.Producer feastCoreClient core.CoreServiceClient standardTransformerConfig config.StandardTransformerConfig - webhookManager webhookManager.WebhookManager + webhook webhook.Client } func NewEndpointService(params EndpointServiceParams) EndpointsService { @@ -100,7 +99,7 @@ func NewEndpointService(params EndpointServiceParams) EndpointsService { jobProducer: params.JobProducer, feastCoreClient: params.FeastCoreClient, standardTransformerConfig: params.StandardTransformerConfig, - webhookManager: params.WebhookManager, + webhook: params.Webhook, } } @@ -138,18 +137,10 @@ func (k *endpointService) DeployEndpoint(ctx context.Context, environment *model return nil, err } - // calling webhooks if there's any webhooks configured - if k.webhookManager != nil && k.webhookManager.IsEventConfigured(webhooks.OnModelVersionPredeployment) { - body := &webhooks.VersionEndpointRequest{ - EventType: webhooks.OnModelVersionPredeployment, - VersionEndpoint: endpoint, - } - - err := k.webhookManager.InvokeWebhooks(ctx, webhooks.OnModelVersionPredeployment, body, webhookManager.NoOpCallback, webhookManager.NoOpErrorHandler) - if err != nil { - log.Errorf("unable to invoke webhooks for event type: %s, model: %s, version: %s, error: %v", webhooks.OnModelVersionPredeployment, model.Name, version.ID, err) - return nil, err - } + // calling webhook if there's any webhook configured + if err = k.webhook.TriggerWebhooks(ctx, webhook.OnVersionEndpointPredeployment, webhook.SetBody(endpoint)); err != nil { + log.Errorf("unable to invoke webhook for event type: %s, model: %s, endpoint: %d, error: %v", webhook.OnVersionEndpointPredeployment, endpoint.VersionModelID, endpoint.ID, err) + return nil, err } // Copy to avoid race condition @@ -301,17 +292,9 @@ func (k *endpointService) UndeployEndpoint(ctx context.Context, environment *mod return nil, err } - // calling webhooks if there's any webhooks configured - if k.webhookManager != nil && k.webhookManager.IsEventConfigured(webhooks.OnModelVersionUndeployed) { - body := &webhooks.VersionEndpointRequest{ - EventType: webhooks.OnModelVersionUndeployed, - VersionEndpoint: endpoint, - } - - err = k.webhookManager.InvokeWebhooks(ctx, webhooks.OnModelVersionUndeployed, body, webhookManager.NoOpCallback, webhookManager.NoOpErrorHandler) - if err != nil { - log.Warnf("unable to invoke webhooks for event type: %s, model: %s, version: %s, error: %v", webhooks.OnModelVersionUndeployed, model.Name, version.ID, err) - } + // calling webhook if there's any webhook configured + if err = k.webhook.TriggerWebhooks(ctx, webhook.OnVersionEndpointUndeployed, webhook.SetBody(endpoint)); err != nil { + log.Warnf("unable to invoke webhook for event type: %s, model: %s, endpoint: %d, error: %v", webhook.OnVersionEndpointUndeployed, endpoint.VersionModelID, endpoint.ID, err) } return endpoint, nil diff --git a/api/service/version_endpoint_service_test.go b/api/service/version_endpoint_service_test.go index 362021ae8..6f9dbf2b6 100644 --- a/api/service/version_endpoint_service_test.go +++ b/api/service/version_endpoint_service_test.go @@ -22,6 +22,7 @@ import ( "testing" "time" + webhookMock "github.com/caraml-dev/merlin/webhook/mocks" webhookManager "github.com/caraml-dev/mlp/api/pkg/webhooks" "github.com/feast-dev/feast/sdk/go/protos/feast/core" "github.com/feast-dev/feast/sdk/go/protos/feast/types" @@ -48,7 +49,7 @@ import ( "github.com/caraml-dev/merlin/pkg/transformer/spec" queueMock "github.com/caraml-dev/merlin/queue/mocks" "github.com/caraml-dev/merlin/storage/mocks" - webhooks "github.com/caraml-dev/merlin/webhooks" + webhooks "github.com/caraml-dev/merlin/webhook" ) var ( @@ -837,14 +838,11 @@ func TestDeployEndpoint(t *testing.T) { imgBuilder := &imageBuilderMock.ImageBuilder{} mockStorage := &mocks.VersionEndpointStorage{} mockDeploymentStorage := &mocks.DeploymentStorage{} - mockWebhook := webhookManager.NewMockWebhookManager(t) + mockWebhook := webhookMock.NewClient(t) mockStorage.On("Save", mock.Anything).Return(nil) mockDeploymentStorage.On("Save", mock.Anything).Return(nil, nil) - mockWebhook.On("IsEventConfigured", webhooks.OnModelVersionPredeployment).Return(tt.args.isWebhookExist) - if tt.args.isWebhookExist { - mockWebhook.On("InvokeWebhooks", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) - } + mockWebhook.On("TriggerWebhooks", mock.Anything, webhooks.OnVersionEndpointPredeployment, mock.Anything).Return(nil) mockCfg := &config.Config{ Environment: "dev", @@ -866,7 +864,7 @@ func TestDeployEndpoint(t *testing.T) { MonitoringConfig: mockCfg.FeatureToggleConfig.MonitoringConfig, LoggerDestinationURL: loggerDestinationURL, JobProducer: mockQueueProducer, - WebhookManager: mockWebhook, + Webhook: mockWebhook, }) actualEndpoint, err := endpointSvc.DeployEndpoint(context.Background(), tt.args.environment, tt.args.model, tt.args.version, tt.args.endpoint) if tt.wantDeployError { @@ -898,10 +896,6 @@ func TestDeployEndpoint(t *testing.T) { if tt.args.endpoint.Transformer != nil { assert.Equal(t, tt.args.endpoint.Transformer.Enabled, actualEndpoint.Transformer.Enabled) } - - if tt.args.isWebhookExist { - mockWebhook.AssertNumberOfCalls(t, "InvokeWebhooks", 1) - } }) } } @@ -2146,6 +2140,12 @@ func TestDeployEndpoint_StandardTransformer(t *testing.T) { mockQueueProducer.On("EnqueueJob", mock.Anything).Return(nil) + mockWebhook := webhookMock.NewClient(t) + if tC.err == nil { + mockWebhook.On("TriggerWebhooks", mock.Anything, webhooks.OnVersionEndpointPredeployment, mock.Anything).Return(nil) + + } + imgBuilder := &imageBuilderMock.ImageBuilder{} mockStorage := &mocks.VersionEndpointStorage{} mockDeploymentStorage := &mocks.DeploymentStorage{} @@ -2185,6 +2185,7 @@ func TestDeployEndpoint_StandardTransformer(t *testing.T) { JobProducer: mockQueueProducer, StandardTransformerConfig: mockCfg.StandardTransformerConfig, FeastCoreClient: mockFeastCore, + Webhook: mockWebhook, }) createdEndpoint, err := endpointSvc.DeployEndpoint(context.Background(), tC.environment, tC.model, tC.version, tC.endpoint) if err != nil { @@ -2388,9 +2389,8 @@ func assertElementMatchFeatureTableMetadata(t *testing.T, expectation []*spec.Fe func TestUndeployEndpoint(t *testing.T) { type webhooksArgs struct { - event webhookManager.EventType - isEventConfigured bool - err error + event webhookManager.EventType + err error } type args struct { @@ -2432,9 +2432,8 @@ func TestUndeployEndpoint(t *testing.T) { Status: models.EndpointRunning, }, &webhooksArgs{ - event: webhooks.OnModelVersionUndeployed, - isEventConfigured: false, - err: nil, + event: webhooks.OnVersionEndpointUndeployed, + err: nil, }, }, expectedEndpoint: &models.VersionEndpoint{ @@ -2453,8 +2452,7 @@ func TestUndeployEndpoint(t *testing.T) { Status: models.EndpointRunning, }, &webhooksArgs{ - event: webhooks.OnModelVersionUndeployed, - isEventConfigured: true, + event: webhooks.OnVersionEndpointUndeployed, }, }, expectedEndpoint: &models.VersionEndpoint{ @@ -2472,8 +2470,7 @@ func TestUndeployEndpoint(t *testing.T) { Status: models.EndpointRunning, }, &webhooksArgs{ - event: webhooks.OnModelVersionUndeployed, - isEventConfigured: false, + event: webhooks.OnVersionEndpointUndeployed, }, }, expectedEndpoint: &models.VersionEndpoint{ @@ -2492,7 +2489,7 @@ func TestUndeployEndpoint(t *testing.T) { imgBuilder := &imageBuilderMock.ImageBuilder{} mockStorage := &mocks.VersionEndpointStorage{} mockDeploymentStorage := &mocks.DeploymentStorage{} - mockWebhook := webhookManager.NewMockWebhookManager(t) + mockWebhook := webhookMock.NewClient(t) envController.On("Delete", mock.Anything, mock.Anything).Return(nil, nil) mockStorage.On("Save", mock.Anything).Return(nil) @@ -2500,10 +2497,7 @@ func TestUndeployEndpoint(t *testing.T) { mockDeploymentStorage.On("Undeploy", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("Failed to undeploy")) } else { mockDeploymentStorage.On("Undeploy", mock.Anything, mock.Anything, mock.Anything).Return(nil) - mockWebhook.On("IsEventConfigured", mock.Anything).Return(tt.args.webhooks.isEventConfigured) - } - if tt.args.webhooks.isEventConfigured { - mockWebhook.On("InvokeWebhooks", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + mockWebhook.On("TriggerWebhooks", mock.Anything, webhooks.OnVersionEndpointUndeployed, mock.Anything).Return(nil) } mockCfg := &config.Config{ @@ -2526,7 +2520,7 @@ func TestUndeployEndpoint(t *testing.T) { MonitoringConfig: mockCfg.FeatureToggleConfig.MonitoringConfig, LoggerDestinationURL: loggerDestinationURL, JobProducer: mockQueueProducer, - WebhookManager: mockWebhook, + Webhook: mockWebhook, }) actualEndpoint, err := endpointSvc.UndeployEndpoint(context.Background(), env, model, version, tt.args.endpoint) @@ -2535,17 +2529,11 @@ func TestUndeployEndpoint(t *testing.T) { mockStorage.AssertNumberOfCalls(t, "Save", 1) mockDeploymentStorage.AssertNumberOfCalls(t, "Undeploy", 1) - if tt.args.webhooks.isEventConfigured { - mockWebhook.AssertNumberOfCalls(t, "InvokeWebhooks", 1) - } - if tt.wantUndeployError { assert.Error(t, err) return } - mockWebhook.AssertNumberOfCalls(t, "IsEventConfigured", 1) - assert.NoError(t, err) assert.Equal(t, tt.expectedEndpoint, actualEndpoint) }) diff --git a/api/webhook/mocks/webhook.go b/api/webhook/mocks/webhook.go new file mode 100644 index 000000000..c80a9c9da --- /dev/null +++ b/api/webhook/mocks/webhook.go @@ -0,0 +1,56 @@ +// Code generated by mockery v2.44.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + webhook "github.com/caraml-dev/merlin/webhook" + mock "github.com/stretchr/testify/mock" + + webhooks "github.com/caraml-dev/mlp/api/pkg/webhooks" +) + +// Client is an autogenerated mock type for the Client type +type Client struct { + mock.Mock +} + +// TriggerWebhooks provides a mock function with given fields: ctx, event, opts +func (_m *Client) TriggerWebhooks(ctx context.Context, event webhooks.EventType, opts ...webhook.Option) error { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, event) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for TriggerWebhooks") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, webhooks.EventType, ...webhook.Option) error); ok { + r0 = rf(ctx, event, opts...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewClient creates a new instance of Client. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewClient(t interface { + mock.TestingT + Cleanup(func()) +}) *Client { + mock := &Client{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/api/webhook/request.go b/api/webhook/request.go new file mode 100644 index 000000000..40859473e --- /dev/null +++ b/api/webhook/request.go @@ -0,0 +1,10 @@ +package webhook + +import ( + "github.com/caraml-dev/mlp/api/pkg/webhooks" +) + +type WebhookRequest struct { + EventType webhooks.EventType `json:"event_type"` + Data map[string]interface{} `json:"data"` +} diff --git a/api/webhook/webhook.go b/api/webhook/webhook.go new file mode 100644 index 000000000..055fa7f8c --- /dev/null +++ b/api/webhook/webhook.go @@ -0,0 +1,128 @@ +package webhook + +import ( + "context" + + "github.com/caraml-dev/merlin/log" + "github.com/caraml-dev/merlin/models" + "github.com/caraml-dev/mlp/api/pkg/webhooks" + webhookManager "github.com/caraml-dev/mlp/api/pkg/webhooks" +) + +var ( + OnModelCreated = webhooks.EventType("on-model-created") + + OnModelEndpointCreated = webhooks.EventType("on-model-endpoint-created") + OnModelEndpointUpdated = webhooks.EventType("on-model-endpoint-updated") + OnModelEndpointDeleted = webhooks.EventType("on-model-endpoint-deleted") + + OnModelVersionCreated = webhooks.EventType("on-model-version-created") + OnModelVersionUpdated = webhooks.EventType("on-model-version-updated") + OnModelVersionDeleted = webhooks.EventType("on-model-version-deleted") + + OnVersionEndpointPredeployment = webhooks.EventType("on-version-endpoint-predeployment") + OnVersionEndpointDeployed = webhooks.EventType("on-version-endpoint-deployed") + OnVersionEndpointUndeployed = webhooks.EventType("on-version-endpoint-undeployed") +) + +var events = []webhooks.EventType{ + OnModelCreated, + OnModelEndpointCreated, + OnModelEndpointUpdated, + OnModelEndpointDeleted, + OnModelVersionCreated, + OnModelVersionUpdated, + OnModelVersionDeleted, + OnVersionEndpointPredeployment, + OnVersionEndpointDeployed, + OnVersionEndpointUndeployed, +} + +type Webhook struct { + webhookManager webhookManager.WebhookManager +} + +type Option func(*config) + +type config struct { + successCallback func(payload []byte) error + errorCallback func(error) error + body map[string]interface{} +} + +type Client interface { + TriggerWebhooks(ctx context.Context, event webhooks.EventType, opts ...Option) error +} + +func NewWebhook(cfg *webhookManager.Config) *Webhook { + manager, err := webhookManager.InitializeWebhooks(cfg, events) + if err != nil { + log.Panicf("failed to initialize webhook: %s", err) + } + + return &Webhook{ + webhookManager: manager, + } +} + +func newDefaultOption() *config { + return &config{ + successCallback: webhooks.NoOpCallback, + errorCallback: webhooks.NoOpErrorHandler, + } +} + +func (w Webhook) TriggerWebhooks(ctx context.Context, event webhooks.EventType, opts ...Option) error { + if !w.isEventConfigured(event) { + return nil + } + + conf := newDefaultOption() + for _, opt := range opts { + opt(conf) + } + + b := &WebhookRequest{ + EventType: event, + Data: conf.body, + } + + return w.webhookManager.InvokeWebhooks(ctx, event, b, conf.successCallback, conf.errorCallback) +} + +func (w Webhook) isEventConfigured(event webhooks.EventType) bool { + return w.webhookManager != nil && w.webhookManager.IsEventConfigured(event) +} + +func SetSuccessCallBack(f func(payload []byte) error) Option { + return func(c *config) { + c.successCallback = f + } +} + +func SetErrorCallback(f func(error) error) Option { + return func(c *config) { + c.errorCallback = f + } +} + +func SetBody(items ...interface{}) Option { + body := make(map[string]interface{}) + + for _, item := range items { + switch item.(type) { + case *models.Model: + body["model"] = item + case *models.Version: + body["model_version"] = item + case *models.ModelEndpoint: + body["model_endpoint"] = item + case *models.VersionEndpoint: + body["version_endpoint"] = item + } + } + + return func(c *config) { + c.body = body + } +} diff --git a/api/webhook/webhook_test.go b/api/webhook/webhook_test.go new file mode 100644 index 000000000..aa5eb7b94 --- /dev/null +++ b/api/webhook/webhook_test.go @@ -0,0 +1,102 @@ +package webhook + +import ( + "context" + "fmt" + "testing" + + "github.com/caraml-dev/merlin/models" + webhookManager "github.com/caraml-dev/mlp/api/pkg/webhooks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestTriggerWebhooks(t *testing.T) { + defaultModel := &models.Model{ + ID: 2, + ProjectID: 1, + Name: "my-model", + } + + testCases := []struct { + name string + event webhookManager.EventType + webhook func() *webhookManager.MockWebhookManager + call func(w *Webhook) error + wantError bool + }{ + { + name: "Success: event is configured", + event: OnModelCreated, + webhook: func() *webhookManager.MockWebhookManager { + w := webhookManager.NewMockWebhookManager(t) + w.On("IsEventConfigured", OnModelCreated).Return(true) + w.On("InvokeWebhooks", mock.Anything, OnModelCreated, mock.Anything, mock.Anything, mock.Anything).Return(nil) + return w + }, + wantError: false, + }, + { + name: "Success: event is not configured", + event: OnModelCreated, + webhook: func() *webhookManager.MockWebhookManager { + w := webhookManager.NewMockWebhookManager(t) + w.On("IsEventConfigured", OnModelCreated).Return(false) + return w + }, + wantError: false, + }, + { + name: "Success: with body set up", + event: OnModelCreated, + webhook: func() *webhookManager.MockWebhookManager { + body := &WebhookRequest{ + EventType: OnModelCreated, + Data: map[string]interface{}{ + "model": defaultModel, + }, + } + w := webhookManager.NewMockWebhookManager(t) + w.On("IsEventConfigured", OnModelCreated).Return(true) + w.On("InvokeWebhooks", mock.Anything, OnModelCreated, body, mock.Anything, mock.Anything).Return(nil) + return w + }, + call: func(w *Webhook) error { + return w.TriggerWebhooks(context.Background(), OnModelCreated, SetBody(defaultModel)) + }, + wantError: false, + }, + { + name: "Fail: there was a webhook error", + event: OnModelCreated, + webhook: func() *webhookManager.MockWebhookManager { + w := webhookManager.NewMockWebhookManager(t) + w.On("IsEventConfigured", OnModelCreated).Return(true) + w.On("InvokeWebhooks", mock.Anything, OnModelCreated, mock.Anything, mock.Anything, mock.Anything).Return(fmt.Errorf("there was an error")) + return w + }, + wantError: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + w := &Webhook{ + webhookManager: tc.webhook(), + } + + var err error + if tc.call != nil { + err = tc.call(w) + } else { + err = w.TriggerWebhooks(context.Background(), tc.event) + } + + if tc.wantError { + assert.Error(t, err) + return + } + + assert.NoError(t, err) + }) + } +} diff --git a/api/webhooks/webhooks.go b/api/webhooks/webhooks.go deleted file mode 100644 index 4c60e4969..000000000 --- a/api/webhooks/webhooks.go +++ /dev/null @@ -1,24 +0,0 @@ -package webhooks - -import ( - "github.com/caraml-dev/mlp/api/pkg/webhooks" - - "github.com/caraml-dev/merlin/models" -) - -var ( - OnModelVersionPredeployment = webhooks.EventType("on-model-version-predeployment") - OnModelVersionDeployed = webhooks.EventType("on-model-version-deployed") - OnModelVersionUndeployed = webhooks.EventType("on-model-version-undeployed") -) - -var WebhookEvents = []webhooks.EventType{ - OnModelVersionPredeployment, - OnModelVersionDeployed, - OnModelVersionUndeployed, -} - -type VersionEndpointRequest struct { - EventType webhooks.EventType `json:"event_type"` - VersionEndpoint *models.VersionEndpoint `json:"version_endpoint"` -}