From da2803ab384a83d352663100f299f2c0699836dd Mon Sep 17 00:00:00 2001 From: bthari Date: Mon, 19 Aug 2024 16:28:42 +0700 Subject: [PATCH] feat: add webhook to model version deployment & undeployment (#601) # Description When deploying or undeploying model endpoint, other entity might want to get a trigger to automate their process or for Merlin other process. This PR add webhooks call (based on the [webhook from MLP](https://github.com/caraml-dev/mlp/blob/main/api/pkg/webhooks/README.md)), so on model version deployment/undeployment it will call the configured webhooks. # Modifications **Main Changes:** Added webhooks call on - Model version pre-deployment: will ignore the success response (for this version) and will stop/fail the deployment version if any `async` webhook fail - Model version post-deployment: will ignore error, only log the error if any occur during the call - Model version post-undeployment: will ignore error, only log the error if any occur during the call Request payload to webhook: - Request: - `event_type`: name of event which triggers the webhook - `versionEndpoint`: object version endpoint **Side effect changes**: - With MLP update in go.mod, the assert function is also updated. Previously, the `assert.InEpsilon` can pass when item in actual slice is in expected slice, even though the expected slice might have more items; now the `testify/assert` will check the two slices length first ([ref](https://github.com/stretchr/testify/pull/1483/files)) -> Added some changes to fix the unit test in `TestToFloat64List` # Tests # Checklist - [x] Added PR label - [x] Added unit test, integration, and/or e2e tests - [x] Tested locally - [ ] Updated documentation - [ ] Update Swagger spec if the PR introduce API changes - [ ] Regenerated Golang and Python client if the PR introduces API changes # Release Notes ```release-note add configurable webhook call in endpoint deployment and undeployment ``` --- api/cmd/api/main.go | 12 +- api/cmd/api/setup.go | 7 +- api/config/config.go | 2 + api/config/config_test.go | 20 ++ api/config/testdata/base-configs-1.yaml | 12 + api/go.mod | 7 +- api/go.sum | 13 +- .../types/converter/converter_test.go | 2 +- api/queue/work/model_service_deployment.go | 16 ++ .../work/model_service_deployment_test.go | 107 ++++++++ api/service/version_endpoint_service.go | 33 +++ api/service/version_endpoint_service_test.go | 231 +++++++++++++++++- api/webhooks/webhooks.go | 24 ++ 13 files changed, 465 insertions(+), 21 deletions(-) create mode 100644 api/webhooks/webhooks.go diff --git a/api/cmd/api/main.go b/api/cmd/api/main.go index 35411300f..c25b987f4 100644 --- a/api/cmd/api/main.go +++ b/api/cmd/api/main.go @@ -27,8 +27,8 @@ import ( "time" mlflowDelete "github.com/caraml-dev/mlp/api/pkg/client/mlflow" - "github.com/caraml-dev/mlp/api/pkg/instrumentation/sentry" + webhookManager "github.com/caraml-dev/mlp/api/pkg/webhooks" _ "github.com/golang-migrate/migrate/v4/source/file" "github.com/gorilla/mux" "github.com/heptiolabs/healthcheck" @@ -49,6 +49,7 @@ import ( "github.com/caraml-dev/merlin/service" "github.com/caraml-dev/merlin/storage" "github.com/caraml-dev/merlin/warden" + "github.com/caraml-dev/merlin/webhooks" "github.com/caraml-dev/mlp/api/pkg/authz/enforcer" "github.com/caraml-dev/mlp/api/pkg/instrumentation/newrelic" ) @@ -267,6 +268,11 @@ func buildDependencies(ctx context.Context, cfg *config.Config, db *gorm.DB, dis log.Panicf("invalid deployment label prefix (%s): %s", cfg.DeploymentLabelPrefix, err) } + webhookClient, err := webhookManager.InitializeWebhooks(&cfg.WebhooksConfig, webhooks.WebhookEvents) + if err != nil { + log.Panicf("failed to initialize webhooks: %s", err) + } + webServiceBuilder, predJobBuilder, imageBuilderJanitor := initImageBuilder(cfg) observabilityPublisherStorage := storage.NewObservabilityPublisherStorage(db) @@ -274,8 +280,8 @@ func buildDependencies(ctx context.Context, cfg *config.Config, db *gorm.DB, dis versionStorage := storage.NewVersionStorage(db) observabilityEvent := event.NewEventProducer(dispatcher, observabilityPublisherStorage, versionStorage) clusterControllers := initClusterControllers(cfg) - modelServiceDeployment := initModelServiceDeployment(cfg, webServiceBuilder, clusterControllers, db, observabilityEvent) - versionEndpointService := initVersionEndpointService(cfg, webServiceBuilder, clusterControllers, db, coreClient, dispatcher) + modelServiceDeployment := initModelServiceDeployment(cfg, webServiceBuilder, clusterControllers, db, observabilityEvent, webhookClient) + versionEndpointService := initVersionEndpointService(cfg, webServiceBuilder, clusterControllers, db, coreClient, dispatcher, webhookClient) modelEndpointService := initModelEndpointService(cfg, db, observabilityEvent) batchControllers := initBatchControllers(cfg, db, mlpAPIClient) diff --git a/api/cmd/api/setup.go b/api/cmd/api/setup.go index 4fabae5be..180885343 100644 --- a/api/cmd/api/setup.go +++ b/api/cmd/api/setup.go @@ -10,6 +10,7 @@ import ( "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/client/clientset/versioned" "github.com/caraml-dev/mlp/api/pkg/artifact" "github.com/caraml-dev/mlp/api/pkg/auth" + "github.com/caraml-dev/mlp/api/pkg/webhooks" feast "github.com/feast-dev/feast/sdk/go" "github.com/feast-dev/feast/sdk/go/protos/feast/core" "google.golang.org/grpc" @@ -421,7 +422,7 @@ func initPredictionJobService(cfg *config.Config, controllers map[string]batch.C return service.NewPredictionJobService(controllers, builder, predictionJobStorage, clock.RealClock{}, cfg.Environment, producer) } -func initModelServiceDeployment(cfg *config.Config, builder imagebuilder.ImageBuilder, controllers map[string]cluster.Controller, db *gorm.DB, observabilityEvent event.EventProducer) *work.ModelServiceDeployment { +func initModelServiceDeployment(cfg *config.Config, builder imagebuilder.ImageBuilder, controllers map[string]cluster.Controller, db *gorm.DB, observabilityEvent event.EventProducer, webhookManager webhooks.WebhookManager) *work.ModelServiceDeployment { return &work.ModelServiceDeployment{ ClusterControllers: controllers, ImageBuilder: builder, @@ -430,6 +431,7 @@ func initModelServiceDeployment(cfg *config.Config, builder imagebuilder.ImageBu LoggerDestinationURL: cfg.LoggerDestinationURL, MLObsLoggerDestinationURL: cfg.MLObsLoggerDestinationURL, ObservabilityEventProducer: observabilityEvent, + WebhookManager: webhookManager, } } @@ -502,7 +504,7 @@ func initClusterControllers(cfg *config.Config) map[string]cluster.Controller { return controllers } -func initVersionEndpointService(cfg *config.Config, builder imagebuilder.ImageBuilder, controllers map[string]cluster.Controller, db *gorm.DB, feastCoreClient core.CoreServiceClient, producer queue.Producer) service.EndpointsService { +func initVersionEndpointService(cfg *config.Config, builder imagebuilder.ImageBuilder, controllers map[string]cluster.Controller, db *gorm.DB, feastCoreClient core.CoreServiceClient, producer queue.Producer, webhookManager webhooks.WebhookManager) service.EndpointsService { return service.NewEndpointService(service.EndpointServiceParams{ ClusterControllers: controllers, ImageBuilder: builder, @@ -514,6 +516,7 @@ func initVersionEndpointService(cfg *config.Config, builder imagebuilder.ImageBu JobProducer: producer, FeastCoreClient: feastCoreClient, StandardTransformerConfig: cfg.StandardTransformerConfig, + WebhookManager: webhookManager, }) } diff --git a/api/config/config.go b/api/config/config.go index 6c6996c3b..01127598e 100644 --- a/api/config/config.go +++ b/api/config/config.go @@ -24,6 +24,7 @@ import ( mlpcluster "github.com/caraml-dev/mlp/api/pkg/cluster" "github.com/caraml-dev/mlp/api/pkg/instrumentation/newrelic" "github.com/caraml-dev/mlp/api/pkg/instrumentation/sentry" + "github.com/caraml-dev/mlp/api/pkg/webhooks" "github.com/go-playground/validator/v10" "github.com/mitchellh/mapstructure" "github.com/ory/viper" @@ -70,6 +71,7 @@ type Config struct { PyFuncPublisherConfig PyFuncPublisherConfig InferenceServiceDefaults InferenceServiceDefaults ObservabilityPublisher ObservabilityPublisher + WebhooksConfig webhooks.Config } // UIConfig stores the configuration for the UI. diff --git a/api/config/config_test.go b/api/config/config_test.go index cf494faa4..230184bc7 100644 --- a/api/config/config_test.go +++ b/api/config/config_test.go @@ -25,6 +25,7 @@ import ( mlpcluster "github.com/caraml-dev/mlp/api/pkg/cluster" "github.com/caraml-dev/mlp/api/pkg/instrumentation/newrelic" "github.com/caraml-dev/mlp/api/pkg/instrumentation/sentry" + "github.com/caraml-dev/mlp/api/pkg/webhooks" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/protobuf/types/known/durationpb" @@ -598,6 +599,25 @@ func TestLoad(t *testing.T) { }, DeploymentTimeout: 30 * time.Minute, }, + WebhooksConfig: webhooks.Config{ + Enabled: true, + Config: map[webhooks.EventType][]webhooks.WebhookConfig{ + "on-model-deployed": { + { + Name: "sync-webhooks", + URL: "http://127.0.0.1:8000/sync-webhook", + Method: "POST", + FinalResponse: true, + }, + { + Name: "async-webhooks", + URL: "http://127.0.0.1:8000/async-webhook", + Method: "POST", + Async: true, + }, + }, + }, + }, }, }, "missing file": { diff --git a/api/config/testdata/base-configs-1.yaml b/api/config/testdata/base-configs-1.yaml index 65dfc7ceb..6db2b10de 100644 --- a/api/config/testdata/base-configs-1.yaml +++ b/api/config/testdata/base-configs-1.yaml @@ -149,3 +149,15 @@ InferenceServiceDefaults: DefaultEnvVarsWithoutCPULimits: - Name: foo Value: bar +WebhooksConfig: + Enabled: true + Config: + On-Model-Deployed: + - URL: http://127.0.0.1:8000/sync-webhook + Method: POST + FinalResponse: true + Name: sync-webhooks + - URL: http://127.0.0.1:8000/async-webhook + Method: POST + Name: async-webhooks + Async: true diff --git a/api/go.mod b/api/go.mod index 1cdfd70b2..1f9c67111 100644 --- a/api/go.mod +++ b/api/go.mod @@ -12,7 +12,7 @@ require ( github.com/bboughton/gcp-helpers v0.1.0 github.com/buger/jsonparser v1.1.1 github.com/caraml-dev/merlin-pyspark-app v0.0.3 - github.com/caraml-dev/mlp v1.12.2-0.20240517121307-b89dab536aab + github.com/caraml-dev/mlp v1.13.2-rc2 github.com/caraml-dev/protopath v0.1.0 github.com/caraml-dev/universal-prediction-interface v1.0.0 github.com/cenkalti/backoff/v4 v4.2.1 @@ -64,7 +64,7 @@ require ( github.com/rs/cors v1.8.2 github.com/soheilhy/cmux v0.1.5 github.com/spaolacci/murmur3 v1.1.0 - github.com/stretchr/testify v1.8.4 + github.com/stretchr/testify v1.9.0 github.com/xanzy/go-gitlab v0.32.0 go.opencensus.io v0.24.0 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 @@ -213,7 +213,7 @@ require ( github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/viper v1.8.1 // indirect - github.com/stretchr/objx v0.5.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/subosito/gotenv v1.2.0 // indirect github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 // indirect github.com/valyala/fastjson v1.6.3 // indirect @@ -246,6 +246,7 @@ require ( ) require ( + github.com/avast/retry-go/v4 v4.6.0 // indirect golang.org/x/time v0.5.0 // indirect k8s.io/klog/v2 v2.120.1 // indirect k8s.io/kube-openapi v0.0.0-20231113174909-778a5567bc1e // indirect diff --git a/api/go.sum b/api/go.sum index 87e05793a..f608b5513 100644 --- a/api/go.sum +++ b/api/go.sum @@ -137,6 +137,8 @@ github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmV github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= +github.com/avast/retry-go/v4 v4.6.0 h1:K9xNA+KeB8HHc2aWFuLb25Offp+0iVRXEvFx8IinRJA= +github.com/avast/retry-go/v4 v4.6.0/go.mod h1:gvWlPhBVsvBbLkVGDg/KwvBv0bEkCOLRRSHKIr2PyOE= github.com/aws/aws-sdk-go v1.17.7/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= github.com/aws/aws-sdk-go v1.50.0 h1:HBtrLeO+QyDKnc3t1+5DR1RxodOHCGr8ZcrHudpv7jI= github.com/aws/aws-sdk-go v1.50.0/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= @@ -159,8 +161,8 @@ github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dR github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= -github.com/caraml-dev/mlp v1.12.2-0.20240517121307-b89dab536aab h1:+XKM4kEBZz1gEbOHrphso6HxmMGSfss9TyMBIE0hm2M= -github.com/caraml-dev/mlp v1.12.2-0.20240517121307-b89dab536aab/go.mod h1:Zdz4bALO9WOHXhOgsoLmCjMCJnDVEZEnQFg8rk+u2cE= +github.com/caraml-dev/mlp v1.13.2-rc2 h1:Zmyoy3OTPv2fU+42rxMwUt9erS9J6QA0nlZQy/xCPtk= +github.com/caraml-dev/mlp v1.13.2-rc2/go.mod h1:jKfnUEpCcARv/aJF6qH7vT7VMKICDVOq/pDFvj6V3vQ= github.com/caraml-dev/protopath v0.1.0 h1:hjJ/U9RGD6QZ0Ee9SIYbVmwPugps4S5EpL6R+5ZrBe0= github.com/caraml-dev/protopath v0.1.0/go.mod h1:hVA2HkTrMYv+Q57gtrzu9/P7EXlNtBUcTz43z6EE010= github.com/caraml-dev/universal-prediction-interface v1.0.0 h1:3Z6adv1XZnBVRzFIeCu3mPcPnJrdB5IByYfdD9K/atI= @@ -998,8 +1000,9 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.1-0.20190311161405-34c6fa2dc709/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= @@ -1010,8 +1013,8 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stvp/go-udp-testing v0.0.0-20201019212854-469649b16807/go.mod h1:7jxmlfBCDBXRzr0eAQJ48XC1hBu1np4CS5+cHEYfwpc= github.com/subosito/gotenv v1.2.0 h1:Slr1R9HxAlEKefgq5jn9U+DnETlIUa6HfgEzj0g5d7s= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= diff --git a/api/pkg/transformer/types/converter/converter_test.go b/api/pkg/transformer/types/converter/converter_test.go index 5f17e5913..e71c648b7 100644 --- a/api/pkg/transformer/types/converter/converter_test.go +++ b/api/pkg/transformer/types/converter/converter_test.go @@ -1882,7 +1882,7 @@ func TestToFloat64List(t *testing.T) { { name: "from []float32", args: args{ - val: []float32{float32(3.14), float32(math.NaN())}, + val: []float32{float32(3.14), float32(4.56), float32(math.NaN())}, }, want: []float64{3.14, 4.56}, wantErr: false, diff --git a/api/queue/work/model_service_deployment.go b/api/queue/work/model_service_deployment.go index 571e87577..a34b4776c 100644 --- a/api/queue/work/model_service_deployment.go +++ b/api/queue/work/model_service_deployment.go @@ -15,6 +15,8 @@ import ( "github.com/caraml-dev/merlin/pkg/observability/event" "github.com/caraml-dev/merlin/queue" "github.com/caraml-dev/merlin/storage" + "github.com/caraml-dev/merlin/webhooks" + webhookManager "github.com/caraml-dev/mlp/api/pkg/webhooks" "github.com/prometheus/client_golang/prometheus" "gorm.io/gorm" ) @@ -42,6 +44,7 @@ type ModelServiceDeployment struct { LoggerDestinationURL string MLObsLoggerDestinationURL string ObservabilityEventProducer event.EventProducer + WebhookManager webhookManager.WebhookManager } type EndpointJob struct { @@ -208,6 +211,19 @@ func (depl *ModelServiceDeployment) Deploy(job *queue.Job) error { } } + // calling webhooks if there's any webhooks configured + if depl.WebhookManager != nil && depl.WebhookManager.IsEventConfigured(webhooks.OnModelVersionDeployed) { + body := &webhooks.VersionEndpointRequest{ + EventType: webhooks.OnModelVersionDeployed, + VersionEndpoint: endpoint, + } + + err = depl.WebhookManager.InvokeWebhooks(ctx, webhooks.OnModelVersionDeployed, body, webhookManager.NoOpCallback, webhookManager.NoOpErrorHandler) + if err != nil { + log.Warnf("unable to invoke webhooks for event type: %s, model: %s, version: %s, error: %v", webhooks.OnModelVersionDeployed, model.Name, version.ID, err) + } + } + return nil } diff --git a/api/queue/work/model_service_deployment_test.go b/api/queue/work/model_service_deployment_test.go index 8983a29f5..f3e95895b 100644 --- a/api/queue/work/model_service_deployment_test.go +++ b/api/queue/work/model_service_deployment_test.go @@ -14,6 +14,8 @@ import ( eventMock "github.com/caraml-dev/merlin/pkg/observability/event/mocks" "github.com/caraml-dev/merlin/queue" "github.com/caraml-dev/merlin/storage/mocks" + webhook "github.com/caraml-dev/merlin/webhooks" + webhookManager "github.com/caraml-dev/mlp/api/pkg/webhooks" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "gorm.io/gorm" @@ -72,6 +74,7 @@ func TestExecuteDeployment(t *testing.T) { storage func() *mocks.VersionEndpointStorage controller func() *clusterMock.Controller imageBuilder func() *imageBuilderMock.ImageBuilder + webhookManager func() webhookManager.WebhookManager eventProducer *eventMock.EventProducer }{ { @@ -117,6 +120,11 @@ func TestExecuteDeployment(t *testing.T) { mockImgBuilder := &imageBuilderMock.ImageBuilder{} return mockImgBuilder }, + webhookManager: func() webhookManager.WebhookManager { + webhookManager := webhookManager.NewMockWebhookManager(t) + webhookManager.On("IsEventConfigured", mock.Anything).Return(false) + return webhookManager + }, }, { name: "Success: Default - Model Observability Supported", @@ -162,6 +170,11 @@ func TestExecuteDeployment(t *testing.T) { mockImgBuilder := &imageBuilderMock.ImageBuilder{} return mockImgBuilder }, + webhookManager: func() webhookManager.WebhookManager { + webhookManager := webhookManager.NewMockWebhookManager(t) + webhookManager.On("IsEventConfigured", mock.Anything).Return(false) + return webhookManager + }, eventProducer: func() *eventMock.EventProducer { producer := &eventMock.EventProducer{} producer.On("VersionEndpointChangeEvent", &models.VersionEndpoint{ @@ -178,6 +191,57 @@ func TestExecuteDeployment(t *testing.T) { return producer }(), }, + { + name: "Success: with calling webhooks", + model: model, + version: version, + endpoint: &models.VersionEndpoint{ + EnvironmentName: env.Name, + ResourceRequest: env.DefaultResourceRequest, + VersionID: version.ID, + Namespace: project.Name, + }, + deploymentStorage: func() *mocks.DeploymentStorage { + mockStorage := createDefaultMockDeploymentStorage() + mockStorage.On("OnDeploymentSuccess", mock.Anything).Return(nil) + return mockStorage + }, + storage: func() *mocks.VersionEndpointStorage { + mockStorage := &mocks.VersionEndpointStorage{} + mockStorage.On("Save", mock.Anything).Return(nil) + mockStorage.On("Get", mock.Anything).Return(&models.VersionEndpoint{ + Environment: env, + EnvironmentName: env.Name, + ResourceRequest: env.DefaultResourceRequest, + VersionID: version.ID, + Namespace: project.Name, + }, nil) + return mockStorage + }, + controller: func() *clusterMock.Controller { + ctrl := &clusterMock.Controller{} + ctrl.On("Deploy", mock.Anything, mock.Anything). + Return(&models.Service{ + Name: iSvcName, + Namespace: project.Name, + ServiceName: svcName, + URL: url, + Metadata: svcMetadata, + }, nil) + return ctrl + }, + imageBuilder: func() *imageBuilderMock.ImageBuilder { + mockImgBuilder := &imageBuilderMock.ImageBuilder{} + return mockImgBuilder + }, + webhookManager: func() webhookManager.WebhookManager { + manager := webhookManager.NewMockWebhookManager(t) + + manager.On("IsEventConfigured", mock.Anything).Return(true) + manager.On("InvokeWebhooks", mock.Anything, webhook.OnModelVersionDeployed, mock.IsType(&webhook.VersionEndpointRequest{}), mock.Anything, mock.Anything).Return(nil) + return manager + }, + }, { name: "Success eventhough error when produce event", model: &models.Model{Name: "model", Project: project, ObservabilitySupported: true}, @@ -222,6 +286,11 @@ func TestExecuteDeployment(t *testing.T) { mockImgBuilder := &imageBuilderMock.ImageBuilder{} return mockImgBuilder }, + webhookManager: func() webhookManager.WebhookManager { + webhookManager := webhookManager.NewMockWebhookManager(t) + webhookManager.On("IsEventConfigured", mock.Anything).Return(false) + return webhookManager + }, eventProducer: func() *eventMock.EventProducer { producer := &eventMock.EventProducer{} producer.On("VersionEndpointChangeEvent", &models.VersionEndpoint{ @@ -289,6 +358,11 @@ func TestExecuteDeployment(t *testing.T) { mockImgBuilder := &imageBuilderMock.ImageBuilder{} return mockImgBuilder }, + webhookManager: func() webhookManager.WebhookManager { + webhookManager := webhookManager.NewMockWebhookManager(t) + webhookManager.On("IsEventConfigured", mock.Anything).Return(false) + return webhookManager + }, }, { name: "Success: Latest deployment entry in storage not in pending state", @@ -341,6 +415,11 @@ func TestExecuteDeployment(t *testing.T) { mockImgBuilder := &imageBuilderMock.ImageBuilder{} return mockImgBuilder }, + webhookManager: func() webhookManager.WebhookManager { + webhookManager := webhookManager.NewMockWebhookManager(t) + webhookManager.On("IsEventConfigured", mock.Anything).Return(false) + return webhookManager + }, }, { name: "Success: Pytorch Model", @@ -385,6 +464,11 @@ func TestExecuteDeployment(t *testing.T) { mockImgBuilder := &imageBuilderMock.ImageBuilder{} return mockImgBuilder }, + webhookManager: func() webhookManager.WebhookManager { + webhookManager := webhookManager.NewMockWebhookManager(t) + webhookManager.On("IsEventConfigured", mock.Anything).Return(false) + return webhookManager + }, }, { name: "Success: empty pyfunc model", @@ -431,6 +515,11 @@ func TestExecuteDeployment(t *testing.T) { Return("gojek/mymodel-1:latest", nil) return mockImgBuilder }, + webhookManager: func() webhookManager.WebhookManager { + webhookManager := webhookManager.NewMockWebhookManager(t) + webhookManager.On("IsEventConfigured", mock.Anything).Return(false) + return webhookManager + }, }, { name: "Success: pytorch model with transformer", @@ -477,6 +566,11 @@ func TestExecuteDeployment(t *testing.T) { Return("gojek/mymodel-1:latest", nil) return mockImgBuilder }, + webhookManager: func() webhookManager.WebhookManager { + webhookManager := webhookManager.NewMockWebhookManager(t) + webhookManager.On("IsEventConfigured", mock.Anything).Return(false) + return webhookManager + }, }, { name: "Success: Default With GPU", @@ -535,6 +629,11 @@ func TestExecuteDeployment(t *testing.T) { mockImgBuilder := &imageBuilderMock.ImageBuilder{} return mockImgBuilder }, + webhookManager: func() webhookManager.WebhookManager { + webhookManager := webhookManager.NewMockWebhookManager(t) + webhookManager.On("IsEventConfigured", mock.Anything).Return(false) + return webhookManager + }, }, { name: "Failed: deployment failed", @@ -573,6 +672,9 @@ func TestExecuteDeployment(t *testing.T) { mockImgBuilder := &imageBuilderMock.ImageBuilder{} return mockImgBuilder }, + webhookManager: func() webhookManager.WebhookManager { + return webhookManager.NewMockWebhookManager(t) + }, }, { name: "Failed: image builder failed", @@ -610,6 +712,9 @@ func TestExecuteDeployment(t *testing.T) { mockImgBuilder.On("BuildImage", context.Background(), mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return("", errors.New("Failed to build image")) return mockImgBuilder }, + webhookManager: func() webhookManager.WebhookManager { + return webhookManager.NewMockWebhookManager(t) + }, }, } for _, tt := range tests { @@ -619,6 +724,7 @@ func TestExecuteDeployment(t *testing.T) { imgBuilder := tt.imageBuilder() mockStorage := tt.storage() mockDeploymentStorage := tt.deploymentStorage() + mockWebhook := tt.webhookManager() job := &queue.Job{ Name: "job", Arguments: queue.Arguments{ @@ -637,6 +743,7 @@ func TestExecuteDeployment(t *testing.T) { DeploymentStorage: mockDeploymentStorage, LoggerDestinationURL: loggerDestinationURL, ObservabilityEventProducer: tt.eventProducer, + WebhookManager: mockWebhook, } err := svc.Deploy(job) diff --git a/api/service/version_endpoint_service.go b/api/service/version_endpoint_service.go index ea34264a0..bfd712232 100644 --- a/api/service/version_endpoint_service.go +++ b/api/service/version_endpoint_service.go @@ -21,6 +21,7 @@ import ( "github.com/caraml-dev/merlin/cluster" "github.com/caraml-dev/merlin/config" + "github.com/caraml-dev/merlin/log" "github.com/caraml-dev/merlin/models" "github.com/caraml-dev/merlin/pkg/autoscaling" "github.com/caraml-dev/merlin/pkg/deployment" @@ -32,6 +33,8 @@ import ( "github.com/caraml-dev/merlin/queue" "github.com/caraml-dev/merlin/queue/work" "github.com/caraml-dev/merlin/storage" + "github.com/caraml-dev/merlin/webhooks" + webhookManager "github.com/caraml-dev/mlp/api/pkg/webhooks" "github.com/feast-dev/feast/sdk/go/protos/feast/core" "github.com/google/uuid" "google.golang.org/protobuf/encoding/protojson" @@ -66,6 +69,7 @@ type EndpointServiceParams struct { JobProducer queue.Producer FeastCoreClient core.CoreServiceClient StandardTransformerConfig config.StandardTransformerConfig + WebhookManager webhookManager.WebhookManager } type endpointService struct { @@ -80,6 +84,7 @@ type endpointService struct { jobProducer queue.Producer feastCoreClient core.CoreServiceClient standardTransformerConfig config.StandardTransformerConfig + webhookManager webhookManager.WebhookManager } func NewEndpointService(params EndpointServiceParams) EndpointsService { @@ -95,6 +100,7 @@ func NewEndpointService(params EndpointServiceParams) EndpointsService { jobProducer: params.JobProducer, feastCoreClient: params.FeastCoreClient, standardTransformerConfig: params.StandardTransformerConfig, + webhookManager: params.WebhookManager, } } @@ -132,6 +138,20 @@ func (k *endpointService) DeployEndpoint(ctx context.Context, environment *model return nil, err } + // calling webhooks if there's any webhooks configured + if k.webhookManager != nil && k.webhookManager.IsEventConfigured(webhooks.OnModelVersionPredeployment) { + body := &webhooks.VersionEndpointRequest{ + EventType: webhooks.OnModelVersionPredeployment, + VersionEndpoint: endpoint, + } + + err := k.webhookManager.InvokeWebhooks(ctx, webhooks.OnModelVersionPredeployment, body, webhookManager.NoOpCallback, webhookManager.NoOpErrorHandler) + if err != nil { + log.Errorf("unable to invoke webhooks for event type: %s, model: %s, version: %s, error: %v", webhooks.OnModelVersionPredeployment, model.Name, version.ID, err) + return nil, err + } + } + // Copy to avoid race condition tobeDeployedEndpoint := *endpoint @@ -281,6 +301,19 @@ func (k *endpointService) UndeployEndpoint(ctx context.Context, environment *mod return nil, err } + // calling webhooks if there's any webhooks configured + if k.webhookManager != nil && k.webhookManager.IsEventConfigured(webhooks.OnModelVersionUndeployed) { + body := &webhooks.VersionEndpointRequest{ + EventType: webhooks.OnModelVersionUndeployed, + VersionEndpoint: endpoint, + } + + err = k.webhookManager.InvokeWebhooks(ctx, webhooks.OnModelVersionUndeployed, body, webhookManager.NoOpCallback, webhookManager.NoOpErrorHandler) + if err != nil { + log.Warnf("unable to invoke webhooks for event type: %s, model: %s, version: %s, error: %v", webhooks.OnModelVersionUndeployed, model.Name, version.ID, err) + } + } + return endpoint, nil } diff --git a/api/service/version_endpoint_service_test.go b/api/service/version_endpoint_service_test.go index 3853eef5e..362021ae8 100644 --- a/api/service/version_endpoint_service_test.go +++ b/api/service/version_endpoint_service_test.go @@ -22,6 +22,7 @@ import ( "testing" "time" + webhookManager "github.com/caraml-dev/mlp/api/pkg/webhooks" "github.com/feast-dev/feast/sdk/go/protos/feast/core" "github.com/feast-dev/feast/sdk/go/protos/feast/types" "github.com/google/uuid" @@ -47,6 +48,7 @@ import ( "github.com/caraml-dev/merlin/pkg/transformer/spec" queueMock "github.com/caraml-dev/merlin/queue/mocks" "github.com/caraml-dev/merlin/storage/mocks" + webhooks "github.com/caraml-dev/merlin/webhooks" ) var ( @@ -56,10 +58,11 @@ var ( func TestDeployEndpoint(t *testing.T) { type args struct { - environment *models.Environment - model *models.Model - version *models.Version - endpoint *models.VersionEndpoint + environment *models.Environment + model *models.Model + version *models.Version + endpoint *models.VersionEndpoint + isWebhookExist bool } env := &models.Environment{ @@ -79,13 +82,12 @@ func TestDeployEndpoint(t *testing.T) { model := &models.Model{Name: "model", Project: project} version := &models.Version{ID: 1} - // iSvcName := fmt.Sprintf("%s-%d-0", model.Name, version.ID) - tests := []struct { name string args args expectedEndpoint *models.VersionEndpoint - wantDeployError bool + + wantDeployError bool }{ { name: "success: new endpoint default resource request", @@ -94,6 +96,7 @@ func TestDeployEndpoint(t *testing.T) { model, version, &models.VersionEndpoint{}, + false, }, expectedEndpoint: &models.VersionEndpoint{ DeploymentMode: deployment.ServerlessDeploymentMode, @@ -120,6 +123,7 @@ func TestDeployEndpoint(t *testing.T) { MemoryRequest: resource.MustParse("1Gi"), }, }, + false, }, expectedEndpoint: &models.VersionEndpoint{ DeploymentMode: deployment.ServerlessDeploymentMode, @@ -144,6 +148,7 @@ func TestDeployEndpoint(t *testing.T) { &models.Model{Name: "model", Project: project, Type: models.ModelTypePyTorch}, &models.Version{ID: 1}, &models.VersionEndpoint{}, + false, }, expectedEndpoint: &models.VersionEndpoint{ DeploymentMode: deployment.ServerlessDeploymentMode, @@ -165,6 +170,7 @@ func TestDeployEndpoint(t *testing.T) { &models.VersionEndpoint{ ResourceRequest: env.DefaultResourceRequest, }, + false, }, expectedEndpoint: &models.VersionEndpoint{ DeploymentMode: deployment.ServerlessDeploymentMode, @@ -186,6 +192,7 @@ func TestDeployEndpoint(t *testing.T) { &models.VersionEndpoint{ ResourceRequest: env.DefaultResourceRequest, }, + false, }, expectedEndpoint: &models.VersionEndpoint{ DeploymentMode: deployment.ServerlessDeploymentMode, @@ -223,6 +230,7 @@ func TestDeployEndpoint(t *testing.T) { }, Protocol: protocol.HttpJson, }, + false, }, expectedEndpoint: &models.VersionEndpoint{ DeploymentMode: deployment.ServerlessDeploymentMode, @@ -252,6 +260,7 @@ func TestDeployEndpoint(t *testing.T) { model, version, &models.VersionEndpoint{}, + false, }, expectedEndpoint: &models.VersionEndpoint{}, wantDeployError: true, @@ -269,6 +278,7 @@ func TestDeployEndpoint(t *testing.T) { ResourceRequest: env.DefaultResourceRequest, }, }, + false, }, expectedEndpoint: &models.VersionEndpoint{ DeploymentMode: deployment.ServerlessDeploymentMode, @@ -311,6 +321,7 @@ func TestDeployEndpoint(t *testing.T) { }, Protocol: protocol.HttpJson, }, + false, }, expectedEndpoint: &models.VersionEndpoint{ DeploymentMode: deployment.ServerlessDeploymentMode, @@ -359,6 +370,7 @@ func TestDeployEndpoint(t *testing.T) { }, }, }, + false, }, expectedEndpoint: &models.VersionEndpoint{ DeploymentMode: deployment.ServerlessDeploymentMode, @@ -403,6 +415,7 @@ func TestDeployEndpoint(t *testing.T) { }, }, }, + false, }, expectedEndpoint: &models.VersionEndpoint{ DeploymentMode: deployment.ServerlessDeploymentMode, @@ -451,6 +464,7 @@ func TestDeployEndpoint(t *testing.T) { }, }, }, + false, }, expectedEndpoint: &models.VersionEndpoint{ DeploymentMode: deployment.ServerlessDeploymentMode, @@ -504,6 +518,7 @@ func TestDeployEndpoint(t *testing.T) { }, DeploymentMode: deployment.RawDeploymentMode, }, + false, }, expectedEndpoint: &models.VersionEndpoint{ Namespace: project.Name, @@ -561,6 +576,7 @@ func TestDeployEndpoint(t *testing.T) { TargetValue: 100, }, }, + false, }, expectedEndpoint: &models.VersionEndpoint{ Namespace: project.Name, @@ -621,6 +637,7 @@ func TestDeployEndpoint(t *testing.T) { TargetValue: 100, }, }, + false, }, expectedEndpoint: &models.VersionEndpoint{ Namespace: project.Name, @@ -725,6 +742,7 @@ func TestDeployEndpoint(t *testing.T) { TargetValue: 10, }, }, + false, }, expectedEndpoint: &models.VersionEndpoint{ Namespace: project.Name, @@ -771,6 +789,7 @@ func TestDeployEndpoint(t *testing.T) { ResourceRequest: env.DefaultResourceRequest, Protocol: protocol.UpiV1, }, + false, }, expectedEndpoint: &models.VersionEndpoint{ DeploymentMode: deployment.ServerlessDeploymentMode, @@ -783,6 +802,26 @@ func TestDeployEndpoint(t *testing.T) { }, wantDeployError: false, }, + { + name: "success: new endpoint with existing webhookManager", + args: args{ + env, + model, + version, + &models.VersionEndpoint{}, + true, + }, + expectedEndpoint: &models.VersionEndpoint{ + DeploymentMode: deployment.ServerlessDeploymentMode, + AutoscalingPolicy: autoscaling.DefaultServerlessAutoscalingPolicy, + ResourceRequest: env.DefaultResourceRequest, + Namespace: project.Name, + URL: "", + Status: models.EndpointPending, + Protocol: protocol.HttpJson, + }, + wantDeployError: false, + }, } for _, tt := range tests { @@ -798,8 +837,15 @@ func TestDeployEndpoint(t *testing.T) { imgBuilder := &imageBuilderMock.ImageBuilder{} mockStorage := &mocks.VersionEndpointStorage{} mockDeploymentStorage := &mocks.DeploymentStorage{} + mockWebhook := webhookManager.NewMockWebhookManager(t) + mockStorage.On("Save", mock.Anything).Return(nil) mockDeploymentStorage.On("Save", mock.Anything).Return(nil, nil) + mockWebhook.On("IsEventConfigured", webhooks.OnModelVersionPredeployment).Return(tt.args.isWebhookExist) + if tt.args.isWebhookExist { + mockWebhook.On("InvokeWebhooks", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + } + mockCfg := &config.Config{ Environment: "dev", FeatureToggleConfig: config.FeatureToggleConfig{ @@ -820,6 +866,7 @@ func TestDeployEndpoint(t *testing.T) { MonitoringConfig: mockCfg.FeatureToggleConfig.MonitoringConfig, LoggerDestinationURL: loggerDestinationURL, JobProducer: mockQueueProducer, + WebhookManager: mockWebhook, }) actualEndpoint, err := endpointSvc.DeployEndpoint(context.Background(), tt.args.environment, tt.args.model, tt.args.version, tt.args.endpoint) if tt.wantDeployError { @@ -851,6 +898,10 @@ func TestDeployEndpoint(t *testing.T) { if tt.args.endpoint.Transformer != nil { assert.Equal(t, tt.args.endpoint.Transformer.Enabled, actualEndpoint.Transformer.Enabled) } + + if tt.args.isWebhookExist { + mockWebhook.AssertNumberOfCalls(t, "InvokeWebhooks", 1) + } }) } } @@ -2334,3 +2385,169 @@ func assertElementMatchFeatureTableMetadata(t *testing.T, expectation []*spec.Fe } assert.True(t, len(expectation) == numOfMatchElements) } + +func TestUndeployEndpoint(t *testing.T) { + type webhooksArgs struct { + event webhookManager.EventType + isEventConfigured bool + err error + } + + type args struct { + endpoint *models.VersionEndpoint + webhooks *webhooksArgs + } + + id := uuid.New() + + env := &models.Environment{ + Name: "env1", + Cluster: "cluster1", + IsDefault: &isDefaultTrue, + Region: "id", + GcpProject: "project", + DefaultResourceRequest: &models.ResourceRequest{ + MinReplica: 0, + MaxReplica: 1, + CPURequest: resource.MustParse("1"), + MemoryRequest: resource.MustParse("1Gi"), + }, + } + project := mlp.Project{Name: "project"} + model := &models.Model{Name: "model", Project: project} + version := &models.Version{ID: 1} + + tests := []struct { + name string + args args + expectedEndpoint *models.VersionEndpoint + wantUndeployError bool + }{ + { + name: "success: without webhookManager", + args: args{ + &models.VersionEndpoint{ + ID: id, + Namespace: project.Name, + Status: models.EndpointRunning, + }, + &webhooksArgs{ + event: webhooks.OnModelVersionUndeployed, + isEventConfigured: false, + err: nil, + }, + }, + expectedEndpoint: &models.VersionEndpoint{ + ID: id, + Namespace: project.Name, + Status: models.EndpointTerminated, + }, + wantUndeployError: false, + }, + { + name: "success: with webhookManager", + args: args{ + &models.VersionEndpoint{ + ID: id, + Namespace: project.Name, + Status: models.EndpointRunning, + }, + &webhooksArgs{ + event: webhooks.OnModelVersionUndeployed, + isEventConfigured: true, + }, + }, + expectedEndpoint: &models.VersionEndpoint{ + ID: id, + Namespace: project.Name, + Status: models.EndpointTerminated, + }, + }, + { + name: "fail: to undeploy", + args: args{ + &models.VersionEndpoint{ + ID: id, + Namespace: project.Name, + Status: models.EndpointRunning, + }, + &webhooksArgs{ + event: webhooks.OnModelVersionUndeployed, + isEventConfigured: false, + }, + }, + expectedEndpoint: &models.VersionEndpoint{ + ID: id, + Namespace: project.Name, + Status: models.EndpointTerminated, + }, + wantUndeployError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + envController := &clusterMock.Controller{} + mockQueueProducer := &queueMock.Producer{} + imgBuilder := &imageBuilderMock.ImageBuilder{} + mockStorage := &mocks.VersionEndpointStorage{} + mockDeploymentStorage := &mocks.DeploymentStorage{} + mockWebhook := webhookManager.NewMockWebhookManager(t) + + envController.On("Delete", mock.Anything, mock.Anything).Return(nil, nil) + mockStorage.On("Save", mock.Anything).Return(nil) + if tt.wantUndeployError { + mockDeploymentStorage.On("Undeploy", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("Failed to undeploy")) + } else { + mockDeploymentStorage.On("Undeploy", mock.Anything, mock.Anything, mock.Anything).Return(nil) + mockWebhook.On("IsEventConfigured", mock.Anything).Return(tt.args.webhooks.isEventConfigured) + } + if tt.args.webhooks.isEventConfigured { + mockWebhook.On("InvokeWebhooks", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + } + + mockCfg := &config.Config{ + Environment: "dev", + FeatureToggleConfig: config.FeatureToggleConfig{ + MonitoringConfig: config.MonitoringConfig{ + MonitoringEnabled: false, + }, + }, + } + + controllers := map[string]cluster.Controller{env.Name: envController} + + endpointSvc := NewEndpointService(EndpointServiceParams{ + ClusterControllers: controllers, + ImageBuilder: imgBuilder, + Storage: mockStorage, + DeploymentStorage: mockDeploymentStorage, + Environment: mockCfg.Environment, + MonitoringConfig: mockCfg.FeatureToggleConfig.MonitoringConfig, + LoggerDestinationURL: loggerDestinationURL, + JobProducer: mockQueueProducer, + WebhookManager: mockWebhook, + }) + + actualEndpoint, err := endpointSvc.UndeployEndpoint(context.Background(), env, model, version, tt.args.endpoint) + + envController.AssertNumberOfCalls(t, "Delete", 1) + mockStorage.AssertNumberOfCalls(t, "Save", 1) + mockDeploymentStorage.AssertNumberOfCalls(t, "Undeploy", 1) + + if tt.args.webhooks.isEventConfigured { + mockWebhook.AssertNumberOfCalls(t, "InvokeWebhooks", 1) + } + + if tt.wantUndeployError { + assert.Error(t, err) + return + } + + mockWebhook.AssertNumberOfCalls(t, "IsEventConfigured", 1) + + assert.NoError(t, err) + assert.Equal(t, tt.expectedEndpoint, actualEndpoint) + }) + } +} diff --git a/api/webhooks/webhooks.go b/api/webhooks/webhooks.go new file mode 100644 index 000000000..4c60e4969 --- /dev/null +++ b/api/webhooks/webhooks.go @@ -0,0 +1,24 @@ +package webhooks + +import ( + "github.com/caraml-dev/mlp/api/pkg/webhooks" + + "github.com/caraml-dev/merlin/models" +) + +var ( + OnModelVersionPredeployment = webhooks.EventType("on-model-version-predeployment") + OnModelVersionDeployed = webhooks.EventType("on-model-version-deployed") + OnModelVersionUndeployed = webhooks.EventType("on-model-version-undeployed") +) + +var WebhookEvents = []webhooks.EventType{ + OnModelVersionPredeployment, + OnModelVersionDeployed, + OnModelVersionUndeployed, +} + +type VersionEndpointRequest struct { + EventType webhooks.EventType `json:"event_type"` + VersionEndpoint *models.VersionEndpoint `json:"version_endpoint"` +}