diff --git a/.cov-ignore b/.cov-ignore new file mode 100644 index 00000000000..37035f002d6 --- /dev/null +++ b/.cov-ignore @@ -0,0 +1,4 @@ +zz_generated.deepcopy.go +openapi_generated.go +pkg\\/client +testing \ No newline at end of file diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 3ac2344c869..a0f1b1879c3 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -12,7 +12,7 @@ jobs: name: Build runs-on: ubuntu-latest steps: - + - name: Set up Go 1.x uses: actions/setup-go@v2 with: @@ -31,9 +31,27 @@ jobs: make fmt - name: Test + id: test run: | export GOPATH=/home/runner/go export PATH=$PATH:/usr/local/kubebuilder/bin:/home/runner/go/bin wget -O $GOPATH/bin/yq https://github.com/mikefarah/yq/releases/download/3.3.2/yq_linux_amd64 chmod +x $GOPATH/bin/yq make test + ./coverage.sh + echo ::set-output name=coverage::$(./coverage.sh | tr -s '\t' | cut -d$'\t' -f 3) + + - name: Print coverage + run: | + echo "Coverage output is ${{ steps.test.outputs.coverage }}" + + - name: Update coverage badge + if: github.ref == 'refs/heads/master' + uses: schneegans/dynamic-badges-action@v1.4.0 + with: + auth: ${{ secrets.GIST_SECRET }} + gistID: 5174bd748ac63a6e4803afea902e9810 + filename: coverage.json + label: coverage + message: ${{ steps.test.outputs.coverage }} + color: green diff --git a/Makefile b/Makefile index d0fd299f572..9fb919f1743 100644 --- a/Makefile +++ b/Makefile @@ -41,7 +41,7 @@ all: test manager agent router # Run tests test: fmt vet manifests envtest - KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) -p path)" go test $$(go list ./pkg/...) ./cmd/... -coverprofile coverage.out + KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) -p path)" go test -v $$(go list ./pkg/...) ./cmd/... -coverprofile coverage.out -coverpkg ./pkg/... ./cmd... # Build manager binary manager: generate fmt vet lint diff --git a/README.md b/README.md index 1174123adaf..4ac4e17d14b 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # KServe [![go.dev reference](https://img.shields.io/badge/go.dev-reference-007d9c?logo=go&logoColor=white)](https://pkg.go.dev/github.com/kserve/kserve) -[![Coverage Status](https://coveralls.io/repos/github/kserve/kserve/badge.svg?branch=master)](https://coveralls.io/github/kserve/kserve?branch=master) +[![Coverage Status](https://img.shields.io/endpoint?url=https://gist.githubusercontent.com/andyi2it/5174bd748ac63a6e4803afea902e9810/raw/coverage.json)](https://github.com/kserve/kserve/actions/workflows/go.yml) [![Go Report Card](https://goreportcard.com/badge/github.com/kserve/kserve)](https://goreportcard.com/report/github.com/kserve/kserve) [![Releases](https://img.shields.io/github/release-pre/kserve/kserve.svg?sort=semver)](https://github.com/kserve/kserve/releases) [![LICENSE](https://img.shields.io/github/license/kserve/kserve.svg)](https://github.com/kserve/kserve/blob/master/LICENSE) diff --git a/coverage.sh b/coverage.sh new file mode 100755 index 00000000000..7b76e54b6af --- /dev/null +++ b/coverage.sh @@ -0,0 +1,10 @@ +#!/bin/sh + +while read p || [ -n "$p" ] +do +sed -i "/${p}/d" ./coverage.out +done < ./.cov-ignore + +go tool cover -func coverage.out > coverage.cov + +tail -1 coverage.cov \ No newline at end of file diff --git a/pkg/agent/downloader_test.go b/pkg/agent/downloader_test.go new file mode 100644 index 00000000000..986f57e1f3d --- /dev/null +++ b/pkg/agent/downloader_test.go @@ -0,0 +1,102 @@ +/* +Copyright 2022 The KServe Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package agent + +import ( + "io/ioutil" + logger "log" + "os" + + "github.com/kserve/kserve/pkg/agent/mocks" + "github.com/kserve/kserve/pkg/agent/storage" + "github.com/kserve/kserve/pkg/apis/serving/v1alpha1" + "github.com/kserve/kserve/pkg/modelconfig" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "go.uber.org/zap" +) + +var _ = Describe("Downloader", func() { + var modelDir string + var downloader *Downloader + BeforeEach(func() { + dir, err := ioutil.TempDir("", "example") + if err != nil { + logger.Fatal(err) + } + modelDir = dir + logger.Printf("Creating temp dir %v\n", modelDir) + zapLogger, _ := zap.NewProduction() + sugar := zapLogger.Sugar() + downloader = &Downloader{ + ModelDir: modelDir + "/test", + Providers: map[storage.Protocol]storage.Provider{ + storage.S3: &storage.S3Provider{ + Client: &mocks.MockS3Client{}, + Downloader: &mocks.MockS3Downloader{}, + }, + }, + Logger: sugar, + } + }) + AfterEach(func() { + os.RemoveAll(modelDir) + logger.Printf("Deleted temp dir %v\n", modelDir) + }) + + Context("When protocol is invalid", func() { + It("Should fail out and return error", func() { + modelConfig := modelconfig.ModelConfig{ + Name: "model1", + Spec: v1alpha1.ModelSpec{ + StorageURI: "sss://models/model1", + Framework: "sklearn", + }, + } + err := downloader.DownloadModel(modelConfig.Name, &modelConfig.Spec) + Expect(err).ShouldNot(BeNil()) + }) + }) + + Context("When storage uri is empty", func() { + It("Should fail out and return error", func() { + modelConfig := modelconfig.ModelConfig{ + Name: "model1", + Spec: v1alpha1.ModelSpec{ + StorageURI: "", + Framework: "sklearn", + }, + } + err := downloader.DownloadModel(modelConfig.Name, &modelConfig.Spec) + Expect(err).ShouldNot(BeNil()) + }) + }) + + Context("When storage uri is invalid", func() { + It("Should fail out and return error", func() { + modelConfig := modelconfig.ModelConfig{ + Name: "model1", + Spec: v1alpha1.ModelSpec{ + StorageURI: "s3:://models/model1", + Framework: "sklearn", + }, + } + err := downloader.DownloadModel(modelConfig.Name, &modelConfig.Spec) + Expect(err).ShouldNot(BeNil()) + }) + }) +}) diff --git a/pkg/agent/storage/utils_test.go b/pkg/agent/storage/utils_test.go index 9cb15fa1397..74b56ecdb5a 100644 --- a/pkg/agent/storage/utils_test.go +++ b/pkg/agent/storage/utils_test.go @@ -20,9 +20,11 @@ import ( "io/ioutil" "os" "path" + "path/filepath" "syscall" "testing" + "github.com/kserve/kserve/pkg/agent/mocks" "github.com/onsi/gomega" ) @@ -48,3 +50,63 @@ func TestCreate(t *testing.T) { expectedMode := os.FileMode(0777) g.Expect(mode.Perm()).To(gomega.Equal(expectedMode)) } + +func TestFileExists(t *testing.T) { + g := gomega.NewGomegaWithT(t) + syscall.Umask(0) + tmpDir, _ := ioutil.TempDir("", "test") + defer os.RemoveAll(tmpDir) + + // Test case for existing file + f, err := os.CreateTemp(tmpDir, "tmpfile") + g.Expect(err).To(gomega.BeNil()) + g.Expect(FileExists(f.Name())).To(gomega.BeTrue()) + f.Close() + + // Test case for not existing file 
+ path := filepath.Join(tmpDir, "fileNotExist") + g.Expect(FileExists(path)).To(gomega.BeFalse()) + + // Test case for directory as filename + g.Expect(FileExists(tmpDir)).To(gomega.BeFalse()) +} + +func TestRemoveDir(t *testing.T) { + g := gomega.NewGomegaWithT(t) + syscall.Umask(0) + tmpDir, _ := ioutil.TempDir("", "test") + subDir, _ := ioutil.TempDir(tmpDir, "test") + os.CreateTemp(subDir, "tmp") + os.CreateTemp(tmpDir, "tmp") + + err := RemoveDir(tmpDir) + g.Expect(err).To(gomega.BeNil()) + _, err = os.Stat(tmpDir) + g.Expect(os.IsNotExist(err)).To(gomega.BeTrue()) + + // Test case for non existing directory + err = RemoveDir("directoryNotExist") + g.Expect(err).NotTo(gomega.BeNil()) +} + +func TestGetProvider(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + // When providers map already have specified provider + mockProviders := map[Protocol]Provider{ + S3: &S3Provider{ + Client: &mocks.MockS3Client{}, + Downloader: &mocks.MockS3Downloader{}, + }, + } + provider, err := GetProvider(mockProviders, S3) + g.Expect(err).To(gomega.BeNil()) + g.Expect(provider).Should(gomega.Equal(mockProviders[S3])) + + // When providers map does not have specified provider + for _, protocol := range SupportedProtocols { + provider, err = GetProvider(map[Protocol]Provider{}, protocol) + g.Expect(err).To(gomega.BeNil()) + g.Expect(provider).ShouldNot(gomega.BeNil()) + } +} diff --git a/pkg/agent/watcher_test.go b/pkg/agent/watcher_test.go index dfe5e2af1e8..bc96c20856d 100644 --- a/pkg/agent/watcher_test.go +++ b/pkg/agent/watcher_test.go @@ -18,6 +18,7 @@ package agent import ( "context" + "encoding/json" "fmt" "io/ioutil" logger "log" @@ -33,6 +34,7 @@ import ( "github.com/kserve/kserve/pkg/agent/mocks" "github.com/kserve/kserve/pkg/agent/storage" "github.com/kserve/kserve/pkg/apis/serving/v1alpha1" + "github.com/kserve/kserve/pkg/constants" "github.com/kserve/kserve/pkg/modelconfig" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" @@ -63,7 +65,6 @@ var _ = Describe("Watcher", func() { It("should download and load the new models", func() { defer GinkgoRecover() logger.Printf("Sync model config using temp dir %v\n", modelDir) - watcher := NewWatcher("/tmp/configs", modelDir, sugar) modelConfigs := modelconfig.ModelConfigs{ { Name: "model1", @@ -82,7 +83,19 @@ var _ = Describe("Watcher", func() { }, }, } - watcher.parseConfig(modelConfigs, false) + _, err := os.Stat("/tmp/configs") + if os.IsNotExist(err) { + if err := os.MkdirAll("/tmp/configs", os.ModePerm); err != nil { + logger.Fatal(err, " Failed to create configs directory") + } + } + + file, _ := json.MarshalIndent(modelConfigs, "", " ") + if err := ioutil.WriteFile("/tmp/configs/"+constants.ModelConfigFileName, file, os.ModePerm); err != nil { + logger.Fatal(err, " Failed to write config files") + } + watcher := NewWatcher("/tmp/configs", modelDir, sugar) + puller := Puller{ channelMap: make(map[string]*ModelChannel), completions: make(chan *ModelOp, 4), @@ -100,12 +113,18 @@ var _ = Describe("Watcher", func() { }, logger: sugar, } + puller.waitGroup.wg.Add(len(watcher.ModelEvents)) go puller.processCommands(watcher.ModelEvents) + Eventually(func() int { return len(puller.channelMap) }).Should(Equal(0)) Eventually(func() int { return puller.opStats["model1"][Add] }).Should(Equal(1)) Eventually(func() int { return puller.opStats["model2"][Add] }).Should(Equal(1)) modelSpecMap, _ := SyncModelDir(modelDir+"/test1", watcher.logger) Expect(watcher.ModelTracker).Should(Equal(modelSpecMap)) + + DeferCleanup(func() { + os.RemoveAll("/tmp/configs") + }) }) }) }) diff --git a/pkg/apis/serving/v1alpha1/inference_graph_validation.go b/pkg/apis/serving/v1alpha1/inference_graph_validation.go index 414d8ed4ead..70f4db53f8d 100644 --- a/pkg/apis/serving/v1alpha1/inference_graph_validation.go +++ b/pkg/apis/serving/v1alpha1/inference_graph_validation.go @@ -30,6 +30,18 @@ import ( const ( // InvalidGraphNameFormatError defines the error message for invalid inference graph name InvalidGraphNameFormatError = "The InferenceGraph \"%s\" is invalid: a InferenceGraph name must consist of lower case alphanumeric characters or '-', and must start with alphabetical character. (e.g. 
\"my-name\" or \"abc-123\", regex used for validation is '%s')" + // RootNodeNotFoundError defines the error message for root node not found + RootNodeNotFoundError = "root node not found, InferenceGraph needs a node with name 'root' as the root node of the graph" + // WeightNotProvidedError defines the error message for traffic weight is nil for inference step + WeightNotProvidedError = "InferenceGraph[%s] Node[%s] Route[%s] missing the 'Weight'" + // InvalidWeightError defines the error message for sum of traffic weight is not 100 + InvalidWeightError = "InferenceGraph[%s] Node[%s] splitter node: the sum of traffic weights for all routing targets should be 100" + // DuplicateStepNameError defines the error message for more than one step contains same name + DuplicateStepNameError = "Node \"%s\" of InferenceGraph \"%s\" contains more than one step with name \"%s\"" + // TargetNotProvidedError defines the error message for inference graph target not specified + TargetNotProvidedError = "Step %d (\"%s\") in node \"%s\" of InferenceGraph \"%s\" does not specify an inference target" + // InvalidTargetError defines the error message for inference graph target specifies more than one of nodeName, serviceName, serviceUrl + InvalidTargetError = "Step %d (\"%s\") in node \"%s\" of InferenceGraph \"%s\" specifies more than one of nodeName, serviceName, serviceUrl" ) const ( @@ -95,7 +107,7 @@ func validateInferenceGraphStepNameUniqueness(ig *InferenceGraph) error { for _, route := range node.Steps { if route.StepName != "" { if nameSet.Has(route.StepName) { - return fmt.Errorf("Node \"%s\" of InferenceGraph \"%s\" contains more than one step with name \"%s\"", + return fmt.Errorf(DuplicateStepNameError, nodeName, ig.Name, route.StepName) } nameSet.Insert(route.StepName) @@ -122,12 +134,10 @@ func validateInferenceGraphSingleStepTargets(ig *InferenceGraph) error { count += 1 } if count == 0 { - return fmt.Errorf("Step %d (\"%s\") in node \"%s\" of InferenceGraph \"%s\" does not specify an inference target", - i, route.StepName, nodeName, ig.Name) + return fmt.Errorf(TargetNotProvidedError, i, route.StepName, nodeName, ig.Name) } if count != 1 { - return fmt.Errorf("Step %d (\"%s\") in node \"%s\" of InferenceGraph \"%s\" specifies more than one of nodeName, serviceName, serviceUrl", - i, route.StepName, nodeName, ig.Name) + return fmt.Errorf(InvalidTargetError, i, route.StepName, nodeName, ig.Name) } } } @@ -150,7 +160,7 @@ func validateInferenceGraphRouterRoot(ig *InferenceGraph) error { return nil } } - return fmt.Errorf("root node not found, InferenceGraph needs a node with name 'root' as the root node of the graph") + return fmt.Errorf(RootNodeNotFoundError) } // Validation of inference graph router type @@ -161,12 +171,12 @@ func validateInferenceGraphSplitterWeight(ig *InferenceGraph) error { if node.RouterType == Splitter { for _, route := range node.Steps { if route.Weight == nil { - return fmt.Errorf("InferenceGraph[%s] Node[%s] Route[%s] missing the 'Weight'", ig.Name, name, route.ServiceName) + return fmt.Errorf(WeightNotProvidedError, ig.Name, name, route.ServiceName) } weight += int(*route.Weight) } if weight != 100 { - return fmt.Errorf("InferenceGraph[%s] Node[%s] splitter node: the sum of traffic weights for all routing targets should be 100", ig.Name, name) + return fmt.Errorf(InvalidWeightError, ig.Name, name) } } } diff --git a/pkg/apis/serving/v1alpha1/inference_graph_validation_test.go b/pkg/apis/serving/v1alpha1/inference_graph_validation_test.go new file mode 100644 index 
00000000000..eb56bad1f78 --- /dev/null +++ b/pkg/apis/serving/v1alpha1/inference_graph_validation_test.go @@ -0,0 +1,316 @@ +package v1alpha1 + +import ( + "fmt" + "github.com/golang/protobuf/proto" + "github.com/onsi/gomega" + "github.com/onsi/gomega/types" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "testing" +) + +func makeTestInferenceGraph() InferenceGraph { + ig := InferenceGraph{ + TypeMeta: metav1.TypeMeta{ + Kind: "InferenceGraph", + APIVersion: "v1alpha1", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "foo-bar", + }, + Spec: InferenceGraphSpec{}, + } + return ig +} + +func TestInferenceGraph_ValidateCreate(t *testing.T) { + g := gomega.NewGomegaWithT(t) + scenarios := map[string]struct { + ig InferenceGraph + update map[string]string + nodes map[string]InferenceRouter + matcher types.GomegaMatcher + }{ + "simple": { + ig: makeTestInferenceGraph(), + nodes: map[string]InferenceRouter{ + GraphRootNodeName: {}, + }, + matcher: gomega.MatchError(nil), + }, + "alphanumeric model name": { + ig: makeTestInferenceGraph(), + update: map[string]string{ + name: "Abc-123", + }, + matcher: gomega.MatchError(fmt.Errorf(InvalidGraphNameFormatError, "Abc-123", GraphNameFmt)), + }, + "name starts with number": { + ig: makeTestInferenceGraph(), + update: map[string]string{ + name: "4abc-3", + }, + nodes: map[string]InferenceRouter{ + GraphRootNodeName: {}, + }, + matcher: gomega.MatchError(fmt.Errorf(InvalidGraphNameFormatError, "4abc-3", GraphNameFmt)), + }, + "name starts with dash": { + ig: makeTestInferenceGraph(), + update: map[string]string{ + name: "-abc-3", + }, + nodes: map[string]InferenceRouter{ + GraphRootNodeName: {}, + }, + matcher: gomega.MatchError(fmt.Errorf(InvalidGraphNameFormatError, "-abc-3", GraphNameFmt)), + }, + "name ends with dash": { + ig: makeTestInferenceGraph(), + update: map[string]string{ + name: "abc-3-", + }, + nodes: map[string]InferenceRouter{ + GraphRootNodeName: {}, + }, + matcher: gomega.MatchError(fmt.Errorf(InvalidGraphNameFormatError, "abc-3-", GraphNameFmt)), + }, + "name includes dot": { + ig: makeTestInferenceGraph(), + update: map[string]string{ + name: "abc.123", + }, + nodes: map[string]InferenceRouter{ + GraphRootNodeName: {}, + }, + matcher: gomega.MatchError(fmt.Errorf(InvalidGraphNameFormatError, "abc.123", GraphNameFmt)), + }, + "name includes spaces": { + ig: makeTestInferenceGraph(), + update: map[string]string{ + name: "abc 123", + }, + nodes: map[string]InferenceRouter{ + GraphRootNodeName: {}, + }, + matcher: gomega.MatchError(fmt.Errorf(InvalidGraphNameFormatError, "abc 123", GraphNameFmt)), + }, + "without root node": { + ig: makeTestInferenceGraph(), + nodes: make(map[string]InferenceRouter), + matcher: gomega.MatchError(fmt.Errorf(RootNodeNotFoundError)), + }, + "with root node": { + ig: makeTestInferenceGraph(), + nodes: map[string]InferenceRouter{ + GraphRootNodeName: {}, + }, + matcher: gomega.MatchError(nil), + }, + "invalid weight for splitter": { + ig: makeTestInferenceGraph(), + nodes: map[string]InferenceRouter{ + GraphRootNodeName: { + RouterType: "Splitter", + Steps: []InferenceStep{ + { + Weight: proto.Int64(80), + InferenceTarget: InferenceTarget{ + NodeName: "test", + }, + }, + { + Weight: proto.Int64(30), + InferenceTarget: InferenceTarget{ + ServiceURL: "http://foo-bar.local/", + }, + }, + }, + }, + }, + matcher: gomega.MatchError(fmt.Errorf(InvalidWeightError, "foo-bar", GraphRootNodeName)), + }, + "weight missing in splitter": { + ig: makeTestInferenceGraph(), + nodes: map[string]InferenceRouter{ + 
GraphRootNodeName: { + RouterType: "Splitter", + Steps: []InferenceStep{ + { + InferenceTarget: InferenceTarget{ + ServiceName: "test", + }, + }, + }, + }, + }, + matcher: gomega.MatchError(fmt.Errorf(WeightNotProvidedError, "foo-bar", GraphRootNodeName, "test")), + }, + "simple splitter": { + ig: makeTestInferenceGraph(), + nodes: map[string]InferenceRouter{ + GraphRootNodeName: { + RouterType: "Splitter", + Steps: []InferenceStep{ + { + Weight: proto.Int64(80), + InferenceTarget: InferenceTarget{ + ServiceName: "service1", + }, + }, + { + Weight: proto.Int64(20), + InferenceTarget: InferenceTarget{ + ServiceName: "service2", + }, + }, + }, + }, + }, + matcher: gomega.MatchError(nil), + }, + "step inference target not provided": { + ig: makeTestInferenceGraph(), + nodes: map[string]InferenceRouter{ + GraphRootNodeName: { + RouterType: "Splitter", + Steps: []InferenceStep{ + { + Weight: proto.Int64(100), + }, + }, + }, + }, + matcher: gomega.MatchError(fmt.Errorf(TargetNotProvidedError, 0, "", GraphRootNodeName, "foo-bar")), + }, + "invalid inference graph target": { + ig: makeTestInferenceGraph(), + nodes: map[string]InferenceRouter{ + GraphRootNodeName: { + RouterType: "Splitter", + Steps: []InferenceStep{ + { + Weight: proto.Int64(100), + InferenceTarget: InferenceTarget{ + ServiceName: "service", + NodeName: "test", + }, + }, + }, + }, + }, + matcher: gomega.MatchError(fmt.Errorf(InvalidTargetError, 0, "", GraphRootNodeName, "foo-bar")), + }, + "duplicate step name": { + ig: makeTestInferenceGraph(), + nodes: map[string]InferenceRouter{ + GraphRootNodeName: { + RouterType: "Splitter", + Steps: []InferenceStep{ + { + StepName: "step1", + Weight: proto.Int64(80), + InferenceTarget: InferenceTarget{ + ServiceName: "service1", + }, + }, + { + StepName: "step1", + Weight: proto.Int64(20), + InferenceTarget: InferenceTarget{ + ServiceName: "service2", + }, + }, + }, + }, + }, + matcher: gomega.MatchError(fmt.Errorf(DuplicateStepNameError, GraphRootNodeName, "foo-bar", "step1")), + }, + } + + for testName, scenario := range scenarios { + t.Run(testName, func(t *testing.T) { + ig := &scenario.ig + for igField, value := range scenario.update { + ig.update(igField, value) + } + ig.Spec.Nodes = scenario.nodes + res := scenario.ig.ValidateCreate() + if !g.Expect(gomega.MatchError(res)).To(gomega.Equal(scenario.matcher)) { + t.Errorf("got %t, want %t", res, scenario.matcher) + } + }) + } +} + +func TestInferenceGraph_ValidateUpdate(t *testing.T) { + g := gomega.NewGomegaWithT(t) + temptIg := makeTestTrainModel() + old := temptIg.DeepCopyObject() + scenarios := map[string]struct { + ig InferenceGraph + update map[string]string + nodes map[string]InferenceRouter + matcher types.GomegaMatcher + }{ + "no change": { + ig: makeTestInferenceGraph(), + nodes: map[string]InferenceRouter{ + GraphRootNodeName: {}, + }, + matcher: gomega.MatchError(nil), + }, + } + + for testName, scenario := range scenarios { + t.Run(testName, func(t *testing.T) { + ig := &scenario.ig + for igField, value := range scenario.update { + ig.update(igField, value) + } + ig.Spec.Nodes = scenario.nodes + res := scenario.ig.ValidateUpdate(old) + if !g.Expect(gomega.MatchError(res)).To(gomega.Equal(scenario.matcher)) { + t.Errorf("got %t, want %t", res, scenario.matcher) + } + }) + } +} + +func TestInferenceGraph_ValidateDelete(t *testing.T) { + g := gomega.NewGomegaWithT(t) + scenarios := map[string]struct { + ig InferenceGraph + update map[string]string + nodes map[string]InferenceRouter + matcher types.GomegaMatcher + }{ + 
"simple": { + ig: makeTestInferenceGraph(), + nodes: map[string]InferenceRouter{ + GraphRootNodeName: {}, + }, + matcher: gomega.MatchError(nil), + }, + } + + for testName, scenario := range scenarios { + t.Run(testName, func(t *testing.T) { + ig := &scenario.ig + for igField, value := range scenario.update { + ig.update(igField, value) + } + ig.Spec.Nodes = scenario.nodes + res := scenario.ig.ValidateDelete() + if !g.Expect(gomega.MatchError(res)).To(gomega.Equal(scenario.matcher)) { + t.Errorf("got %t, want %t", res, scenario.matcher) + } + }) + } +} + +func (ig *InferenceGraph) update(igField string, value string) { + if igField == "Name" { + ig.Name = value + } +} diff --git a/pkg/apis/serving/v1alpha1/servingruntime_types_test.go b/pkg/apis/serving/v1alpha1/servingruntime_types_test.go index 8f28382910e..ff7b7e1c8a2 100644 --- a/pkg/apis/serving/v1alpha1/servingruntime_types_test.go +++ b/pkg/apis/serving/v1alpha1/servingruntime_types_test.go @@ -18,6 +18,8 @@ package v1alpha1 import ( "fmt" + "github.com/golang/protobuf/proto" + "github.com/kserve/kserve/pkg/constants" "testing" v1 "k8s.io/api/core/v1" @@ -79,3 +81,373 @@ func TestMarshalServingRuntime(t *testing.T) { } fmt.Println(string(b)) } + +func TestServingRuntimeSpec_IsDisabled(t *testing.T) { + endpoint := "endpoint" + version := "1.0" + + scenarios := map[string]struct { + spec ServingRuntimeSpec + res bool + }{ + "default behaviour": { + spec: ServingRuntimeSpec{ + GrpcDataEndpoint: &endpoint, + ServingRuntimePodSpec: ServingRuntimePodSpec{ + Containers: []v1.Container{ + { + Args: []string{"arg1", "arg2"}, + Command: []string{"command", "command2"}, + Env: []v1.EnvVar{ + {Name: "name", Value: "value"}, + { + Name: "fromSecret", + ValueFrom: &v1.EnvVarSource{ + SecretKeyRef: &v1.SecretKeySelector{Key: "mykey"}, + }, + }, + }, + Image: "image", + Name: "name", + ImagePullPolicy: "IfNotPresent", + Resources: v1.ResourceRequirements{ + Limits: v1.ResourceList{ + v1.ResourceMemory: resource.MustParse("200Mi"), + }, + }, + }, + }, + }, + SupportedModelFormats: []SupportedModelFormat{ + { + Name: "name", + Version: &version, + }, + }, + }, + res: false, + }, + "specified explicitly": { + spec: ServingRuntimeSpec{ + GrpcDataEndpoint: &endpoint, + Disabled: proto.Bool(true), + ServingRuntimePodSpec: ServingRuntimePodSpec{ + Containers: []v1.Container{ + { + Args: []string{"arg1", "arg2"}, + Command: []string{"command", "command2"}, + Env: []v1.EnvVar{ + {Name: "name", Value: "value"}, + { + Name: "fromSecret", + ValueFrom: &v1.EnvVarSource{ + SecretKeyRef: &v1.SecretKeySelector{Key: "mykey"}, + }, + }, + }, + Image: "image", + Name: "name", + ImagePullPolicy: "IfNotPresent", + Resources: v1.ResourceRequirements{ + Limits: v1.ResourceList{ + v1.ResourceMemory: resource.MustParse("200Mi"), + }, + }, + }, + }, + }, + SupportedModelFormats: []SupportedModelFormat{ + { + Name: "name", + Version: &version, + }, + }, + }, + res: true, + }, + } + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + res := scenario.spec.IsDisabled() + if res != scenario.res { + fmt.Println(fmt.Errorf("Expected %t, got %t", scenario.res, res)) + } + }) + } +} + +func TestServingRuntimeSpec_IsMultiModelRuntime(t *testing.T) { + endpoint := "endpoint" + version := "1.0" + + scenarios := map[string]struct { + spec ServingRuntimeSpec + res bool + }{ + "default behaviour": { + spec: ServingRuntimeSpec{ + GrpcDataEndpoint: &endpoint, + ServingRuntimePodSpec: ServingRuntimePodSpec{ + Containers: []v1.Container{ + { + Args: []string{"arg1", 
"arg2"}, + Command: []string{"command", "command2"}, + Env: []v1.EnvVar{ + {Name: "name", Value: "value"}, + { + Name: "fromSecret", + ValueFrom: &v1.EnvVarSource{ + SecretKeyRef: &v1.SecretKeySelector{Key: "mykey"}, + }, + }, + }, + Image: "image", + Name: "name", + ImagePullPolicy: "IfNotPresent", + Resources: v1.ResourceRequirements{ + Limits: v1.ResourceList{ + v1.ResourceMemory: resource.MustParse("200Mi"), + }, + }, + }, + }, + }, + SupportedModelFormats: []SupportedModelFormat{ + { + Name: "name", + Version: &version, + }, + }, + }, + res: false, + }, + "multimodel specified explicitly": { + spec: ServingRuntimeSpec{ + GrpcDataEndpoint: &endpoint, + MultiModel: proto.Bool(true), + ServingRuntimePodSpec: ServingRuntimePodSpec{ + Containers: []v1.Container{ + { + Args: []string{"arg1", "arg2"}, + Command: []string{"command", "command2"}, + Env: []v1.EnvVar{ + {Name: "name", Value: "value"}, + { + Name: "fromSecret", + ValueFrom: &v1.EnvVarSource{ + SecretKeyRef: &v1.SecretKeySelector{Key: "mykey"}, + }, + }, + }, + Image: "image", + Name: "name", + ImagePullPolicy: "IfNotPresent", + Resources: v1.ResourceRequirements{ + Limits: v1.ResourceList{ + v1.ResourceMemory: resource.MustParse("200Mi"), + }, + }, + }, + }, + }, + SupportedModelFormats: []SupportedModelFormat{ + { + Name: "name", + Version: &version, + }, + }, + }, + res: true, + }, + } + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + res := scenario.spec.IsMultiModelRuntime() + if res != scenario.res { + fmt.Println(fmt.Errorf("Expected %t, got %t", scenario.res, res)) + } + }) + } +} + +func TestServingRuntimeSpec_IsProtocolVersionSupported(t *testing.T) { + endpoint := "endpoint" + version := "1.0" + + scenarios := map[string]struct { + spec ServingRuntimeSpec + protocolVersion constants.InferenceServiceProtocol + res bool + }{ + "v1 protocol": { + spec: ServingRuntimeSpec{ + GrpcDataEndpoint: &endpoint, + ServingRuntimePodSpec: ServingRuntimePodSpec{ + Containers: []v1.Container{ + { + Args: []string{"arg1", "arg2"}, + Command: []string{"command", "command2"}, + Env: []v1.EnvVar{ + {Name: "name", Value: "value"}, + { + Name: "fromSecret", + ValueFrom: &v1.EnvVarSource{ + SecretKeyRef: &v1.SecretKeySelector{Key: "mykey"}, + }, + }, + }, + Image: "image", + Name: "name", + ImagePullPolicy: "IfNotPresent", + Resources: v1.ResourceRequirements{ + Limits: v1.ResourceList{ + v1.ResourceMemory: resource.MustParse("200Mi"), + }, + }, + }, + }, + }, + SupportedModelFormats: []SupportedModelFormat{ + { + Name: "name", + Version: &version, + }, + }, + }, + protocolVersion: constants.ProtocolV1, + res: true, + }, + "v2 protocol": { + spec: ServingRuntimeSpec{ + GrpcDataEndpoint: &endpoint, + MultiModel: proto.Bool(true), + Disabled: proto.Bool(true), + ServingRuntimePodSpec: ServingRuntimePodSpec{ + Containers: []v1.Container{ + { + Args: []string{"arg1", "arg2"}, + Command: []string{"command", "command2"}, + Env: []v1.EnvVar{ + {Name: "name", Value: "value"}, + { + Name: "fromSecret", + ValueFrom: &v1.EnvVarSource{ + SecretKeyRef: &v1.SecretKeySelector{Key: "mykey"}, + }, + }, + }, + Image: "image", + Name: "name", + ImagePullPolicy: "IfNotPresent", + Resources: v1.ResourceRequirements{ + Limits: v1.ResourceList{ + v1.ResourceMemory: resource.MustParse("200Mi"), + }, + }, + }, + }, + }, + SupportedModelFormats: []SupportedModelFormat{ + { + Name: "name", + Version: &version, + }, + }, + }, + protocolVersion: constants.ProtocolV2, + res: false, + }, + "protocols specified": { + spec: ServingRuntimeSpec{ + 
GrpcDataEndpoint: &endpoint, + ProtocolVersions: []constants.InferenceServiceProtocol{ + constants.ProtocolV2, + constants.ProtocolGRPCV2, + }, + ServingRuntimePodSpec: ServingRuntimePodSpec{ + Containers: []v1.Container{ + { + Args: []string{"arg1", "arg2"}, + Command: []string{"command", "command2"}, + Env: []v1.EnvVar{ + {Name: "name", Value: "value"}, + { + Name: "fromSecret", + ValueFrom: &v1.EnvVarSource{ + SecretKeyRef: &v1.SecretKeySelector{Key: "mykey"}, + }, + }, + }, + Image: "image", + Name: "name", + ImagePullPolicy: "IfNotPresent", + Resources: v1.ResourceRequirements{ + Limits: v1.ResourceList{ + v1.ResourceMemory: resource.MustParse("200Mi"), + }, + }, + }, + }, + }, + SupportedModelFormats: []SupportedModelFormat{ + { + Name: "name", + Version: &version, + }, + }, + }, + protocolVersion: constants.ProtocolGRPCV2, + res: true, + }, + "unsupported protocol": { + spec: ServingRuntimeSpec{ + GrpcDataEndpoint: &endpoint, + ProtocolVersions: []constants.InferenceServiceProtocol{ + constants.ProtocolV2, + constants.ProtocolGRPCV2, + }, + ServingRuntimePodSpec: ServingRuntimePodSpec{ + Containers: []v1.Container{ + { + Args: []string{"arg1", "arg2"}, + Command: []string{"command", "command2"}, + Env: []v1.EnvVar{ + {Name: "name", Value: "value"}, + { + Name: "fromSecret", + ValueFrom: &v1.EnvVarSource{ + SecretKeyRef: &v1.SecretKeySelector{Key: "mykey"}, + }, + }, + }, + Image: "image", + Name: "name", + ImagePullPolicy: "IfNotPresent", + Resources: v1.ResourceRequirements{ + Limits: v1.ResourceList{ + v1.ResourceMemory: resource.MustParse("200Mi"), + }, + }, + }, + }, + }, + SupportedModelFormats: []SupportedModelFormat{ + { + Name: "name", + Version: &version, + }, + }, + }, + protocolVersion: constants.ProtocolV1, + res: false, + }, + } + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + res := scenario.spec.IsProtocolVersionSupported(scenario.protocolVersion) + if res != scenario.res { + fmt.Println(fmt.Errorf("Expected %t, got %t", scenario.res, res)) + } + }) + } +} diff --git a/pkg/apis/serving/v1alpha1/trained_model_status_test.go b/pkg/apis/serving/v1alpha1/trained_model_status_test.go new file mode 100644 index 00000000000..b5e2c72248e --- /dev/null +++ b/pkg/apis/serving/v1alpha1/trained_model_status_test.go @@ -0,0 +1,360 @@ +package v1alpha1 + +import ( + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/onsi/gomega" + v1 "k8s.io/api/core/v1" + "knative.dev/pkg/apis" + duckv1 "knative.dev/pkg/apis/duck/v1" + knservingv1 "knative.dev/serving/pkg/apis/serving/v1" + "testing" +) + +func TestTrainedModelStatus_IsReady(t *testing.T) { + cases := []struct { + name string + ServiceStatus TrainedModelStatus + isReady bool + }{{ + name: "empty status should not be ready", + ServiceStatus: TrainedModelStatus{}, + isReady: false, + }, { + name: "Different condition type should not be ready", + ServiceStatus: TrainedModelStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{{ + Type: "Foo", + Status: v1.ConditionTrue, + }}, + }, + }, + isReady: false, + }, { + name: "False condition status should not be ready", + ServiceStatus: TrainedModelStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{{ + Type: knservingv1.ServiceConditionReady, + Status: v1.ConditionFalse, + }}, + }, + }, + isReady: false, + }, { + name: "Unknown condition status should not be ready", + ServiceStatus: TrainedModelStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{{ + Type: knservingv1.ServiceConditionReady, 
+ Status: v1.ConditionUnknown, + }}, + }, + }, + isReady: false, + }, { + name: "Missing condition status should not be ready", + ServiceStatus: TrainedModelStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{{ + Type: knservingv1.ConfigurationConditionReady, + }}, + }, + }, + isReady: false, + }, { + name: "True condition status should be ready", + ServiceStatus: TrainedModelStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{ + { + Type: "ConfigurationsReady", + Status: v1.ConditionTrue, + }, + { + Type: "RoutesReady", + Status: v1.ConditionTrue, + }, + { + Type: knservingv1.ServiceConditionReady, + Status: v1.ConditionTrue, + }, + }, + }, + }, + isReady: true, + }, { + name: "Multiple conditions with ready status false should not be ready", + ServiceStatus: TrainedModelStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{ + { + Type: "Foo", + Status: v1.ConditionTrue, + }, + { + Type: knservingv1.ConfigurationConditionReady, + Status: v1.ConditionFalse, + }, + }, + }, + }, + isReady: false, + }} + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if e, a := tc.isReady, tc.ServiceStatus.IsReady(); e != a { + t.Errorf("%q expected: %v got: %v conditions: %v", tc.name, e, a, tc.ServiceStatus.Conditions) + } + }) + } +} + +func TestTrainedModelStatus_GetCondition(t *testing.T) { + g := gomega.NewGomegaWithT(t) + cases := []struct { + name string + ServiceStatus TrainedModelStatus + Condition apis.ConditionType + matcher *apis.Condition + }{{ + name: "Empty status should return nil", + ServiceStatus: TrainedModelStatus{}, + Condition: knservingv1.ServiceConditionReady, + matcher: nil, + }, { + name: "Get custom condition", + ServiceStatus: TrainedModelStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{{ + Type: "Foo", + Status: v1.ConditionFalse, + }}, + }, + }, + Condition: "Foo", + matcher: &apis.Condition{ + Type: "Foo", + Status: v1.ConditionFalse, + }, + }, { + name: "Get Ready condition", + ServiceStatus: TrainedModelStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{{ + Type: knservingv1.ServiceConditionReady, + Status: v1.ConditionUnknown, + }}, + }, + }, + Condition: knservingv1.ServiceConditionReady, + matcher: &apis.Condition{ + Type: knservingv1.ServiceConditionReady, + Status: v1.ConditionUnknown, + }, + }} + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + res := tc.ServiceStatus.GetCondition(tc.Condition) + g.Expect(res).Should(gomega.Equal(tc.matcher)) + }) + } +} + +func TestTrainedModelStatus_IsConditionReady(t *testing.T) { + cases := []struct { + name string + ServiceStatus TrainedModelStatus + Condition apis.ConditionType + isReady bool + }{{ + name: "empty status should not be ready", + ServiceStatus: TrainedModelStatus{}, + Condition: knservingv1.ServiceConditionReady, + isReady: false, + }, { + name: "Different condition type should not be ready", + ServiceStatus: TrainedModelStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{{ + Type: "Foo", + Status: v1.ConditionTrue, + }}, + }, + }, + Condition: "Bar", + isReady: false, + }, { + name: "False condition status should not be ready", + ServiceStatus: TrainedModelStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{{ + Type: knservingv1.ServiceConditionReady, + Status: v1.ConditionFalse, + }}, + }, + }, + Condition: knservingv1.ServiceConditionReady, + isReady: false, + }, { + name: "Unknown condition status should not be ready", + ServiceStatus: TrainedModelStatus{ + Status: duckv1.Status{ 
+ Conditions: duckv1.Conditions{{ + Type: knservingv1.ServiceConditionReady, + Status: v1.ConditionUnknown, + }}, + }, + }, + Condition: knservingv1.ServiceConditionReady, + isReady: false, + }, { + name: "Missing condition status should not be ready", + ServiceStatus: TrainedModelStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{{ + Type: knservingv1.ConfigurationConditionReady, + }}, + }, + }, + Condition: knservingv1.ServiceConditionReady, + isReady: false, + }, { + name: "True condition status should be ready", + ServiceStatus: TrainedModelStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{ + { + Type: "ConfigurationsReady", + Status: v1.ConditionTrue, + }, + { + Type: "RoutesReady", + Status: v1.ConditionTrue, + }, + { + Type: knservingv1.ServiceConditionReady, + Status: v1.ConditionTrue, + }, + }, + }, + }, + Condition: knservingv1.ServiceConditionReady, + isReady: true, + }, { + name: "Multiple conditions with ready status false should not be ready", + ServiceStatus: TrainedModelStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{ + { + Type: "Foo", + Status: v1.ConditionTrue, + }, + { + Type: knservingv1.ConfigurationConditionReady, + Status: v1.ConditionFalse, + }, + }, + }, + }, + Condition: knservingv1.ConfigurationConditionReady, + isReady: false, + }} + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if e, a := tc.isReady, tc.ServiceStatus.IsConditionReady(tc.Condition); e != a { + t.Errorf("%q expected: %v got: %v conditions: %v", tc.name, e, a, tc.ServiceStatus.Conditions) + } + }) + } +} + +func TestTrainedModelStatus_SetCondition(t *testing.T) { + g := gomega.NewGomegaWithT(t) + cases := []struct { + name string + serviceStatus TrainedModelStatus + condition *apis.Condition + conditionType apis.ConditionType + expected *apis.Condition + }{{ + name: "set condition on empty status", + serviceStatus: TrainedModelStatus{}, + condition: &apis.Condition{ + Type: "Foo", + Status: v1.ConditionTrue, + }, + conditionType: "Foo", + expected: &apis.Condition{ + Type: "Foo", + Status: v1.ConditionTrue, + }, + }, { + name: "modify existing condition", + serviceStatus: TrainedModelStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{{ + Type: "Foo", + Status: v1.ConditionTrue, + }}, + }, + }, + condition: &apis.Condition{ + Type: "Foo", + Status: v1.ConditionFalse, + }, + conditionType: "Foo", + expected: &apis.Condition{ + Type: "Foo", + Status: v1.ConditionFalse, + }, + }, { + name: "set condition unknown", + serviceStatus: TrainedModelStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{{ + Type: "Foo", + Status: v1.ConditionFalse, + }}, + }, + }, + condition: &apis.Condition{ + Type: "Foo", + Status: v1.ConditionUnknown, + Reason: "For testing purpose", + }, + conditionType: "Foo", + expected: &apis.Condition{ + Type: "Foo", + Status: v1.ConditionUnknown, + Reason: "For testing purpose", + }, + }, { + name: "condition is nil", + serviceStatus: TrainedModelStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{{ + Type: knservingv1.ServiceConditionReady, + Status: v1.ConditionTrue, + }}, + }, + }, + condition: nil, + conditionType: "Foo", + expected: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + tc.serviceStatus.SetCondition(tc.conditionType, tc.condition) + res := tc.serviceStatus.GetCondition(tc.conditionType) + g.Expect(cmp.Equal(res, tc.expected, cmpopts.IgnoreFields(apis.Condition{}, "LastTransitionTime", "Severity"))).To(gomega.BeTrue()) + + }) 
+ } +} diff --git a/pkg/apis/serving/v1alpha1/trained_model_test.go b/pkg/apis/serving/v1alpha1/trained_model_test.go new file mode 100644 index 00000000000..38806d0ebb4 --- /dev/null +++ b/pkg/apis/serving/v1alpha1/trained_model_test.go @@ -0,0 +1,60 @@ +package v1alpha1 + +import ( + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "testing" +) + +func TestTrainedModelList_TotalRequestedMemory(t *testing.T) { + list := TrainedModelList{ + TypeMeta: metav1.TypeMeta{}, + ListMeta: metav1.ListMeta{}, + Items: []TrainedModel{ + { + TypeMeta: metav1.TypeMeta{}, + ObjectMeta: metav1.ObjectMeta{ + Name: "test-model-1", + }, + Spec: TrainedModelSpec{ + Model: ModelSpec{ + StorageURI: "http://example.com/", + Framework: "sklearn", + Memory: resource.MustParse("1Gi"), + }, + }, + }, + { + TypeMeta: metav1.TypeMeta{}, + ObjectMeta: metav1.ObjectMeta{ + Name: "test-model-2", + }, + Spec: TrainedModelSpec{ + Model: ModelSpec{ + StorageURI: "http://example.com/", + Framework: "sklearn", + Memory: resource.MustParse("1Gi"), + }, + }, + }, + { + TypeMeta: metav1.TypeMeta{}, + ObjectMeta: metav1.ObjectMeta{ + Name: "test-model-3", + }, + Spec: TrainedModelSpec{ + Model: ModelSpec{ + StorageURI: "http://example.com/", + Framework: "sklearn", + Memory: resource.MustParse("1Gi"), + }, + }, + }, + }, + } + res := list.TotalRequestedMemory() + expected := resource.MustParse("3Gi") + if res.Cmp(&expected) != 0 { + t.Errorf("expected %v got %v", expected, res) + } +} diff --git a/pkg/apis/serving/v1beta1/component_test.go b/pkg/apis/serving/v1beta1/component_test.go index 3ca6a4c51e1..f06be1298dd 100644 --- a/pkg/apis/serving/v1beta1/component_test.go +++ b/pkg/apis/serving/v1beta1/component_test.go @@ -17,10 +17,13 @@ limitations under the License.
package v1beta1 import ( + "fmt" + "strings" + "testing" + "github.com/golang/protobuf/proto" "github.com/onsi/gomega" "github.com/onsi/gomega/types" - "testing" ) func TestComponentExtensionSpec_Validate(t *testing.T) { @@ -54,3 +57,136 @@ func TestComponentExtensionSpec_Validate(t *testing.T) { }) } } + +func TestComponentExtensionSpec_validateStorageSpec(t *testing.T) { + g := gomega.NewGomegaWithT(t) + scenarios := map[string]struct { + spec *StorageSpec + storageUri *string + matcher types.GomegaMatcher + }{ + "ValidStoragespec": { + spec: &StorageSpec{ + Parameters: &map[string]string{ + "type": "s3", + }, + }, + storageUri: nil, + matcher: gomega.BeNil(), + }, + "ValidStoragespecWithoutParameters": { + spec: &StorageSpec{}, + storageUri: nil, + matcher: gomega.BeNil(), + }, + "ValidStoragespecWithStorageURI": { + spec: &StorageSpec{ + Parameters: &map[string]string{ + "type": "s3", + }, + }, + storageUri: proto.String("s3://test/model"), + matcher: gomega.BeNil(), + }, + "StorageSpecWithInvalidStorageURI": { + spec: &StorageSpec{ + Parameters: &map[string]string{ + "type": "gs", + }, + }, + storageUri: proto.String("gs://test/model"), + matcher: gomega.MatchError(fmt.Errorf(UnsupportedStorageURIFormatError, strings.Join(SupportedStorageSpecURIPrefixList, ", "), "gs://test/model")), + }, + "InvalidStoragespec": { + spec: &StorageSpec{ + Parameters: &map[string]string{ + "type": "gs", + }, + }, + storageUri: nil, + matcher: gomega.MatchError(fmt.Errorf(UnsupportedStorageSpecFormatError, strings.Join(SupportedStorageSpecURIPrefixList, ", "), "gs")), + }, + } + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + g.Expect(validateStorageSpec(scenario.spec, scenario.storageUri)).To(scenario.matcher) + }) + } +} + +func TestComponentExtensionSpec_validateLogger(t *testing.T) { + g := gomega.NewGomegaWithT(t) + scenarios := map[string]struct { + logger *LoggerSpec + matcher types.GomegaMatcher + }{ + "LoggerWithLogAllMode": { + logger: &LoggerSpec{ + Mode: LogAll, + }, + matcher: gomega.BeNil(), + }, + "LoggerWithLogRequestMode": { + logger: &LoggerSpec{ + Mode: LogRequest, + }, + matcher: gomega.BeNil(), + }, + "LoggerWithLogResponseMode": { + logger: &LoggerSpec{ + Mode: LogResponse, + }, + matcher: gomega.BeNil(), + }, + "InvalidLoggerMode": { + logger: &LoggerSpec{ + Mode: "InvalidMode", + }, + matcher: gomega.MatchError(fmt.Errorf(InvalidLoggerType)), + }, + "LoggerIsNil": { + logger: nil, + matcher: gomega.BeNil(), + }, + } + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + g.Expect(validateLogger(scenario.logger)).To(scenario.matcher) + }) + } +} + +func TestFirstNonNilComponent(t *testing.T) { + g := gomega.NewGomegaWithT(t) + spec := PredictorSpec{ + SKLearn: &SKLearnSpec{}, + } + scenarios := map[string]struct { + components []ComponentImplementation + matcher types.GomegaMatcher + }{ + "WithNonNilComponent": { + components: []ComponentImplementation{ + spec.PyTorch, + spec.LightGBM, + spec.SKLearn, + spec.Tensorflow, + }, + matcher: gomega.Equal(spec.SKLearn), + }, + "NoNonNilComponents": { + components: []ComponentImplementation{ + spec.PyTorch, + spec.LightGBM, + spec.Tensorflow, + spec.PMML, + }, + matcher: gomega.BeNil(), + }, + } + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + g.Expect(FirstNonNilComponent(scenario.components)).To(scenario.matcher) + }) + } +} diff --git a/pkg/apis/serving/v1beta1/configmap_test.go b/pkg/apis/serving/v1beta1/configmap_test.go new file mode 100644 index 
00000000000..e66a5fcceed --- /dev/null +++ b/pkg/apis/serving/v1beta1/configmap_test.go @@ -0,0 +1,108 @@ +/* +Copyright 2022 The KServe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package v1beta1 + +import ( + ctx "context" + logger "log" + "testing" + + "github.com/kserve/kserve/pkg/constants" + "github.com/onsi/gomega" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/controller-runtime/pkg/client" + fakeclient "sigs.k8s.io/controller-runtime/pkg/client/fake" +) + +func createFakeClient() client.WithWatch { + clientBuilder := fakeclient.NewClientBuilder() + fakeClient := clientBuilder.Build() + configMap := &v1.ConfigMap{ + TypeMeta: metav1.TypeMeta{ + Kind: "ConfigMap", + APIVersion: "", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: constants.InferenceServiceConfigMapName, + Namespace: constants.KServeNamespace, + }, + Immutable: nil, + Data: map[string]string{}, + BinaryData: map[string][]byte{ + ExplainerConfigKeyName: []byte(`{ + "alibi": { + "image" : "kserve/alibi-explainer", + "defaultImageVersion": "latest" + }, + "aix": { + "image" : "kserve/aix-explainer", + "defaultImageVersion": "latest" + }, + "art": { + "image" : "kserve/art-explainer", + "defaultImageVersion": "latest" + } + }`), + IngressConfigKeyName: []byte(`{ + "ingressGateway" : "knative-serving/knative-ingress-gateway", + "ingressService" : "istio-ingressgateway.istio-system.svc.cluster.local", + "localGateway" : "knative-serving/knative-local-gateway", + "localGatewayService" : "knative-local-gateway.istio-system.svc.cluster.local", + "ingressDomain" : "example.com", + "ingressClassName" : "istio", + "domainTemplate": "{{ .Name }}-{{ .Namespace }}.{{ .IngressDomain }}", + "urlScheme": "http" + }`), + DeployConfigName: []byte(`{ + "defaultDeploymentMode": "Serverless" + }`), + }, + } + err := fakeClient.Create(ctx.TODO(), configMap) + if err != nil { + logger.Fatalf("Unable to create configmap: %v", err) + } + return fakeClient +} + +func TestNewInferenceServiceConfig(t *testing.T) { + g := gomega.NewGomegaWithT(t) + fakeClient := createFakeClient() + + isvcConfig, err := NewInferenceServicesConfig(fakeClient) + g.Expect(err).Should(gomega.BeNil()) + g.Expect(isvcConfig).ShouldNot(gomega.BeNil()) +} + +func TestNewIngressConfig(t *testing.T) { + g := gomega.NewGomegaWithT(t) + fakeClient := createFakeClient() + + ingressCfg, err := NewIngressConfig(fakeClient) + g.Expect(err).Should(gomega.BeNil()) + g.Expect(ingressCfg).ShouldNot(gomega.BeNil()) +} + +func TestNewDeployConfig(t *testing.T) { + g := gomega.NewGomegaWithT(t) + fakeClient := createFakeClient() + + deployConfig, err := NewDeployConfig(fakeClient) + g.Expect(err).Should(gomega.BeNil()) + g.Expect(deployConfig).ShouldNot(gomega.BeNil()) +} diff --git a/pkg/apis/serving/v1beta1/inference_service_defaults_test.go b/pkg/apis/serving/v1beta1/inference_service_defaults_test.go index 8cbe695b1f7..a48b696ebdf 100644 ---
a/pkg/apis/serving/v1beta1/inference_service_defaults_test.go +++ b/pkg/apis/serving/v1beta1/inference_service_defaults_test.go @@ -17,12 +17,14 @@ limitations under the License. package v1beta1 import ( + "strconv" "testing" "github.com/golang/protobuf/proto" "github.com/kserve/kserve/pkg/constants" "github.com/onsi/gomega" + "github.com/onsi/gomega/types" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) @@ -82,8 +84,8 @@ func TestInferenceServiceDefaults(t *testing.T) { g.Expect(*&isvc.Spec.Predictor.Tensorflow).To(gomega.BeNil()) g.Expect(*&isvc.Spec.Predictor.Model).NotTo(gomega.BeNil()) + g.Expect(isvc.Spec.Predictor.Model).NotTo(gomega.BeNil()) g.Expect(isvc.Spec.Transformer.PodSpec.Containers[0].Resources).To(gomega.Equal(resources)) - g.Expect(*isvc.Spec.Explainer.Alibi.RuntimeVersion).To(gomega.Equal("v0.4.0")) g.Expect(isvc.Spec.Explainer.Alibi.Resources).To(gomega.Equal(resources)) } @@ -158,3 +160,346 @@ func TestInferenceServiceDefaultsModelMeshAnnotation(t *testing.T) { g.Expect(isvc.Spec.Predictor.Model).To(gomega.BeNil()) g.Expect(isvc.Spec.Predictor.Tensorflow).ToNot(gomega.BeNil()) } + +func TestRuntimeDefaults(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + deployConfig := &DeployConfig{ + DefaultDeploymentMode: "Serverless", + } + scenarios := map[string]struct { + config *InferenceServicesConfig + isvc InferenceService + runtime string + matcher types.GomegaMatcher + }{ + "PyTorch": { + config: &InferenceServicesConfig{}, + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "foo", + Namespace: "default", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + PyTorch: &TorchServeSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://testbucket/testmodel"), + }, + }, + }, + }, + }, + runtime: constants.TorchServe, + matcher: gomega.Equal(constants.ProtocolV1), + }, + "Triton": { + config: &InferenceServicesConfig{}, + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "foo", + Namespace: "default", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + Triton: &TritonSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://testbucket/testmodel"), + }, + }, + }, + }, + }, + runtime: constants.TritonServer, + matcher: gomega.Equal(constants.ProtocolV2), + }, + "MlServer": { + config: &InferenceServicesConfig{}, + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "foo", + Namespace: "default", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + XGBoost: &XGBoostSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://testbucket/testmodel"), + }, + }, + }, + }, + }, + runtime: constants.MLServer, + matcher: gomega.Equal(constants.ProtocolV2), + }, + } + for name, scenario := range scenarios { + scenario.isvc.DefaultInferenceService(scenario.config, deployConfig) + scenario.isvc.Spec.Predictor.Model.Runtime = &scenario.runtime + scenario.isvc.SetRuntimeDefaults() + g.Expect(scenario.isvc.Spec.Predictor.Model).ToNot(gomega.BeNil()) + switch name { + + case "PyTorch": + g.Expect(scenario.isvc.Spec.Predictor.PyTorch).To(gomega.BeNil()) + + case "Triton": + g.Expect(scenario.isvc.Spec.Predictor.Triton).To(gomega.BeNil()) + + case "MlServer": + g.Expect(scenario.isvc.Spec.Predictor.XGBoost).To(gomega.BeNil()) + } + g.Expect(*scenario.isvc.Spec.Predictor.Model.ProtocolVersion).To(scenario.matcher) + } +} + +func TestTorchServeDefaults(t *testing.T) { + g := gomega.NewGomegaWithT(t) + 
+ deployConfig := &DeployConfig{ + DefaultDeploymentMode: "Serverless", + } + protocolVersion := constants.ProtocolV2 + scenarios := map[string]struct { + config *InferenceServicesConfig + isvc InferenceService + matcher types.GomegaMatcher + }{ + "pytorch with protocol version 2": { + config: &InferenceServicesConfig{}, + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "foo", + Namespace: "default", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + PyTorch: &TorchServeSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://testbucket/testmodel"), + ProtocolVersion: &protocolVersion, + }, + }, + }, + }, + }, + matcher: gomega.HaveKeyWithValue(constants.ServiceEnvelope, constants.ServiceEnvelopeKServeV2), + }, + "pytorch with labels": { + config: &InferenceServicesConfig{}, + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "foo", + Namespace: "default", + Labels: map[string]string{ + "Purpose": "Testing", + }, + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + PyTorch: &TorchServeSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://testbucket/testmodel"), + }, + }, + }, + }, + }, + matcher: gomega.HaveKeyWithValue("Purpose", "Testing"), + }, + } + runtime := constants.TorchServe + for _, scenario := range scenarios { + scenario.isvc.DefaultInferenceService(scenario.config, deployConfig) + scenario.isvc.Spec.Predictor.Model.Runtime = &runtime + scenario.isvc.SetTorchServeDefaults() + g.Expect(scenario.isvc.Spec.Predictor.Model).ToNot(gomega.BeNil()) + g.Expect(scenario.isvc.Spec.Predictor.PyTorch).To(gomega.BeNil()) + g.Expect(scenario.isvc.ObjectMeta.Labels).To(scenario.matcher) + } +} + +func TestSetTritonDefaults(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + deployConfig := &DeployConfig{ + DefaultDeploymentMode: "Serverless", + } + scenarios := map[string]struct { + config *InferenceServicesConfig + isvc InferenceService + matcher types.GomegaMatcher + }{ + "Storage URI is nil": { + config: &InferenceServicesConfig{}, + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "foo", + Namespace: "default", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + Triton: &TritonSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{}, + }, + }, + }, + }, + matcher: gomega.ContainElement("--model-control-mode=explicit"), + }, + } + runtime := constants.TritonServer + for _, scenario := range scenarios { + scenario.isvc.DefaultInferenceService(scenario.config, deployConfig) + scenario.isvc.Spec.Predictor.Model.Runtime = &runtime + scenario.isvc.SetTritonDefaults() + g.Expect(scenario.isvc.Spec.Predictor.Model).ToNot(gomega.BeNil()) + g.Expect(scenario.isvc.Spec.Predictor.Triton).To(gomega.BeNil()) + g.Expect(scenario.isvc.Spec.Predictor.Model.Args).To(scenario.matcher) + } +} + +func TestMlServerDefaults(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + deployConfig := &DeployConfig{ + DefaultDeploymentMode: "Serverless", + } + scenarios := map[string]struct { + config *InferenceServicesConfig + isvc InferenceService + matcher map[string]types.GomegaMatcher + }{ + "Storage URI is nil": { + config: &InferenceServicesConfig{}, + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "foo", + Namespace: "default", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + SKLearn: &SKLearnSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{}, + }, + }, + }, + }, + matcher: map[string]types.GomegaMatcher{ + "env": 
gomega.ContainElement(v1.EnvVar{ + Name: constants.MLServerLoadModelsStartupEnv, + Value: strconv.FormatBool(false), + }), + "protocolVersion": gomega.Equal(constants.ProtocolV2), + "labels": gomega.HaveKeyWithValue(constants.ModelClassLabel, constants.MLServerModelClassSKLearn), + }, + }, + "XGBoost model": { + config: &InferenceServicesConfig{}, + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "foo", + Namespace: "default", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + XGBoost: &XGBoostSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://testbucket/testmodel"), + }, + }, + }, + }, + }, + matcher: map[string]types.GomegaMatcher{ + "env": gomega.ContainElements( + v1.EnvVar{ + Name: constants.MLServerModelNameEnv, + Value: "foo", + }, + v1.EnvVar{ + Name: constants.MLServerModelURIEnv, + Value: constants.DefaultModelLocalMountPath, + }), + "protocolVersion": gomega.Equal(constants.ProtocolV2), + "labels": gomega.HaveKeyWithValue(constants.ModelClassLabel, constants.MLServerModelClassXGBoost), + }, + }, + "LightGBM model": { + config: &InferenceServicesConfig{}, + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "foo", + Namespace: "default", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + LightGBM: &LightGBMSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://testbucket/testmodel"), + }, + }, + }, + }, + }, + matcher: map[string]types.GomegaMatcher{ + "env": gomega.ContainElements( + v1.EnvVar{ + Name: constants.MLServerModelNameEnv, + Value: "foo", + }, + v1.EnvVar{ + Name: constants.MLServerModelURIEnv, + Value: constants.DefaultModelLocalMountPath, + }), + "protocolVersion": gomega.Equal(constants.ProtocolV2), + "labels": gomega.HaveKeyWithValue(constants.ModelClassLabel, constants.MLServerModelClassLightGBM), + }, + }, + "LightGBM model with labels": { + config: &InferenceServicesConfig{}, + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "foo", + Namespace: "default", + Labels: map[string]string{ + "Purpose": "Testing", + }, + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + LightGBM: &LightGBMSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://testbucket/testmodel"), + }, + }, + }, + }, + }, + matcher: map[string]types.GomegaMatcher{ + "env": gomega.ContainElements( + v1.EnvVar{ + Name: constants.MLServerModelNameEnv, + Value: "foo", + }, + v1.EnvVar{ + Name: constants.MLServerModelURIEnv, + Value: constants.DefaultModelLocalMountPath, + }), + "protocolVersion": gomega.Equal(constants.ProtocolV2), + "labels": gomega.HaveKeyWithValue("Purpose", "Testing"), + }, + }, + } + runtime := constants.MLServer + for _, scenario := range scenarios { + scenario.isvc.DefaultInferenceService(scenario.config, deployConfig) + scenario.isvc.Spec.Predictor.Model.Runtime = &runtime + scenario.isvc.SetMlServerDefaults() + g.Expect(scenario.isvc.Spec.Predictor.Model).ToNot(gomega.BeNil()) + g.Expect(scenario.isvc.Spec.Predictor.Model.Env).To(scenario.matcher["env"]) + g.Expect(*scenario.isvc.Spec.Predictor.Model.ProtocolVersion).To(scenario.matcher["protocolVersion"]) + g.Expect(scenario.isvc.ObjectMeta.Labels).To(scenario.matcher["labels"]) + } +} diff --git a/pkg/apis/serving/v1beta1/inference_service_status_test.go b/pkg/apis/serving/v1beta1/inference_service_status_test.go index df7c02af2dc..2911110ab3e 100644 --- a/pkg/apis/serving/v1beta1/inference_service_status_test.go +++ 
b/pkg/apis/serving/v1beta1/inference_service_status_test.go @@ -17,12 +17,21 @@ limitations under the License. package v1beta1 import ( + "github.com/kserve/kserve/pkg/constants" + "github.com/onsi/gomega" + "net/url" + "testing" + "time" + + "github.com/golang/protobuf/proto" + appsv1 "k8s.io/api/apps/v1" "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "knative.dev/pkg/apis" "knative.dev/pkg/apis/duck" duckv1 "knative.dev/pkg/apis/duck/v1" duckv1beta1 "knative.dev/pkg/apis/duck/v1beta1" knservingv1 "knative.dev/serving/pkg/apis/serving/v1" - "testing" ) func TestInferenceServiceDuckType(t *testing.T) { @@ -172,3 +181,1048 @@ func TestInferenceServiceIsReady(t *testing.T) { }) } } + +func TestPropagateRawStatus(t *testing.T) { + deployment := &appsv1.Deployment{ + TypeMeta: metav1.TypeMeta{}, + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + "deployment.kubernetes.io/revision": "1", + }, + }, + Spec: appsv1.DeploymentSpec{}, + Status: appsv1.DeploymentStatus{ + Conditions: []appsv1.DeploymentCondition{ + { + Type: appsv1.DeploymentAvailable, + Status: v1.ConditionTrue, + Reason: "MinimumReplicasAvailable", + Message: "Deployment has minimum availability.", + LastTransitionTime: metav1.Time{ + Time: time.Now(), + }, + }, + }, + }, + } + status := &InferenceServiceStatus{ + Status: duckv1.Status{}, + Address: &duckv1.Addressable{}, + URL: &apis.URL{}, + ModelStatus: ModelStatus{}, + } + parsedUrl, _ := url.Parse("http://test-predictor-default.default.example.com") + url := (*apis.URL)(parsedUrl) + status.PropagateRawStatus(PredictorComponent, deployment, url) + if res := status.IsConditionReady(PredictorReady); !res { + t.Errorf("expected: %v got: %v conditions: %v", true, res, status.Conditions) + } +} + +func TestPropagateStatus(t *testing.T) { + parsedUrl, _ := url.Parse("http://test-predictor-default.default.example.com") + cases := []struct { + name string + ServiceStatus knservingv1.ServiceStatus + status InferenceServiceStatus + isReady bool + }{ + { + name: "Status with Traffic Routing for Latest Revision", + ServiceStatus: knservingv1.ServiceStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{ + { + Type: "Foo", + Status: v1.ConditionTrue, + }, + { + Type: "RoutesReady", + Status: v1.ConditionTrue, + }, + { + Type: "ConfigurationsReady", + Status: v1.ConditionTrue, + }, + { + Type: knservingv1.ServiceConditionReady, + Status: v1.ConditionTrue, + }, + }, + }, + ConfigurationStatusFields: knservingv1.ConfigurationStatusFields{ + LatestReadyRevisionName: "test-predictor-default-0001", + }, + RouteStatusFields: knservingv1.RouteStatusFields{ + Traffic: []knservingv1.TrafficTarget{ + { + RevisionName: "test-predictor-default-0001", + Percent: proto.Int64(100), + LatestRevision: proto.Bool(true), + }, + }, + Address: &duckv1.Addressable{}, + URL: (*apis.URL)(parsedUrl), + }, + }, + status: InferenceServiceStatus{}, + isReady: true, + }, + { + name: "Status with Traffic Routing for Rolledout Revision", + ServiceStatus: knservingv1.ServiceStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{ + { + Type: "Foo", + Status: v1.ConditionTrue, + }, + { + Type: "RoutesReady", + Status: v1.ConditionTrue, + }, + { + Type: "ConfigurationsReady", + Status: v1.ConditionTrue, + }, + { + Type: knservingv1.ServiceConditionReady, + Status: v1.ConditionTrue, + }, + }, + }, + ConfigurationStatusFields: knservingv1.ConfigurationStatusFields{ + LatestReadyRevisionName: "test-predictor-default-0001", + LatestCreatedRevisionName: 
"test-predictor-default-0001", + }, + RouteStatusFields: knservingv1.RouteStatusFields{ + Traffic: []knservingv1.TrafficTarget{ + { + RevisionName: "test-predictor-default-0001", + Percent: proto.Int64(90), + LatestRevision: proto.Bool(true), + }, + }, + Address: &duckv1.Addressable{}, + URL: (*apis.URL)(parsedUrl), + }, + }, + status: InferenceServiceStatus{ + Components: map[ComponentType]ComponentStatusSpec{ + PredictorComponent: { + LatestRolledoutRevision: "test-predictor-default-0001", + }, + }, + }, + isReady: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + tc.status.PropagateStatus(PredictorComponent, &tc.ServiceStatus) + if e, a := tc.isReady, tc.status.IsConditionReady(PredictorReady); e != a { + t.Errorf("%q expected: %v got: %v conditions: %v", tc.name, e, a, tc.status.Conditions) + } + if e, a := tc.status.Components[PredictorComponent].Traffic[0], tc.ServiceStatus.Traffic[0]; e != a { + t.Errorf("%q expected: %v got: %v", tc.name, e, a) + } + }) + } +} + +func TestInferenceServiceStatus_PropagateModelStatus(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + scenarios := map[string]struct { + isvcStatus *InferenceServiceStatus + statusSpec ComponentStatusSpec + podList *v1.PodList + rawDeployment bool + expectedRevisionStates *ModelRevisionStates + expectedTransitionStatus TransitionStatus + expectedFailureInfo *FailureInfo + }{ + "pod list is empty": { + isvcStatus: &InferenceServiceStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{ + { + Type: "Ready", + Status: v1.ConditionFalse, + }, + }, + }, + Address: &duckv1.Addressable{}, + URL: &apis.URL{}, + Components: map[ComponentType]ComponentStatusSpec{ + PredictorComponent: { + LatestRolledoutRevision: "test-predictor-default-0001", + }, + }, + ModelStatus: ModelStatus{}, + }, + statusSpec: ComponentStatusSpec{ + LatestReadyRevision: "", + LatestCreatedRevision: "", + PreviousRolledoutRevision: "", + LatestRolledoutRevision: "", + Traffic: nil, + URL: nil, + RestURL: nil, + GrpcURL: nil, + Address: nil, + }, + podList: &v1.PodList{ + TypeMeta: metav1.TypeMeta{}, + ListMeta: metav1.ListMeta{}, + Items: []v1.Pod{}, + }, + rawDeployment: false, + expectedRevisionStates: &ModelRevisionStates{ + ActiveModelState: "", + TargetModelState: Pending, + }, + expectedTransitionStatus: InProgress, + expectedFailureInfo: nil, + }, + "kserve container in pending state": { + isvcStatus: &InferenceServiceStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{ + { + Type: "Ready", + Status: v1.ConditionFalse, + }, + }, + }, + Address: &duckv1.Addressable{}, + URL: &apis.URL{}, + Components: map[ComponentType]ComponentStatusSpec{ + PredictorComponent: { + LatestRolledoutRevision: "test-predictor-default-0001", + }, + }, + ModelStatus: ModelStatus{}, + }, + statusSpec: ComponentStatusSpec{ + LatestReadyRevision: "", + LatestCreatedRevision: "", + PreviousRolledoutRevision: "", + LatestRolledoutRevision: "", + Traffic: nil, + URL: nil, + RestURL: nil, + GrpcURL: nil, + Address: nil, + }, + podList: &v1.PodList{ + TypeMeta: metav1.TypeMeta{}, + ListMeta: metav1.ListMeta{}, + Items: []v1.Pod{ + { + TypeMeta: metav1.TypeMeta{}, + ObjectMeta: metav1.ObjectMeta{ + Name: constants.InferenceServiceContainerName, + }, + Spec: v1.PodSpec{}, + Status: v1.PodStatus{ + ContainerStatuses: []v1.ContainerStatus{ + { + Name: constants.InferenceServiceContainerName, + State: v1.ContainerState{}, + LastTerminationState: v1.ContainerState{}, + Ready: false, + RestartCount: 0, + Image: "", + ImageID: "", 
+ ContainerID: "", + Started: nil, + }, + }, + }, + }, + }, + }, + rawDeployment: false, + expectedRevisionStates: &ModelRevisionStates{ + ActiveModelState: "", + TargetModelState: Pending, + }, + expectedTransitionStatus: InProgress, + expectedFailureInfo: nil, + }, + "kserve container failed due to an error": { + isvcStatus: &InferenceServiceStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{ + { + Type: "Ready", + Status: v1.ConditionFalse, + }, + }, + }, + Address: &duckv1.Addressable{}, + URL: &apis.URL{}, + Components: map[ComponentType]ComponentStatusSpec{ + PredictorComponent: { + LatestRolledoutRevision: "test-predictor-default-0001", + }, + }, + ModelStatus: ModelStatus{}, + }, + statusSpec: ComponentStatusSpec{ + LatestReadyRevision: "", + LatestCreatedRevision: "", + PreviousRolledoutRevision: "", + LatestRolledoutRevision: "", + Traffic: nil, + URL: nil, + RestURL: nil, + GrpcURL: nil, + Address: nil, + }, + podList: &v1.PodList{ + TypeMeta: metav1.TypeMeta{}, + ListMeta: metav1.ListMeta{}, + Items: []v1.Pod{ + { + TypeMeta: metav1.TypeMeta{}, + ObjectMeta: metav1.ObjectMeta{ + Name: constants.InferenceServiceContainerName, + }, + Spec: v1.PodSpec{}, + Status: v1.PodStatus{ + ContainerStatuses: []v1.ContainerStatus{ + { + Name: constants.InferenceServiceContainerName, + State: v1.ContainerState{ + Terminated: &v1.ContainerStateTerminated{ + ExitCode: 1, + Reason: constants.StateReasonError, + Message: "For testing", + }, + }, + LastTerminationState: v1.ContainerState{}, + Ready: false, + RestartCount: 0, + Image: "", + ImageID: "", + ContainerID: "", + Started: nil, + }, + }, + }, + }, + }, + }, + rawDeployment: false, + expectedRevisionStates: &ModelRevisionStates{ + ActiveModelState: "", + TargetModelState: FailedToLoad, + }, + expectedTransitionStatus: BlockedByFailedLoad, + expectedFailureInfo: &FailureInfo{ + Reason: ModelLoadFailed, + Message: "For testing", + ExitCode: 1, + }, + }, + "kserve container failed due to crash loopBackOff": { + isvcStatus: &InferenceServiceStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{ + { + Type: "Ready", + Status: v1.ConditionFalse, + }, + }, + }, + Address: &duckv1.Addressable{}, + URL: &apis.URL{}, + Components: map[ComponentType]ComponentStatusSpec{ + PredictorComponent: { + LatestRolledoutRevision: "test-predictor-default-0001", + }, + }, + ModelStatus: ModelStatus{}, + }, + statusSpec: ComponentStatusSpec{ + LatestReadyRevision: "", + LatestCreatedRevision: "", + PreviousRolledoutRevision: "", + LatestRolledoutRevision: "", + Traffic: nil, + URL: nil, + RestURL: nil, + GrpcURL: nil, + Address: nil, + }, + podList: &v1.PodList{ + TypeMeta: metav1.TypeMeta{}, + ListMeta: metav1.ListMeta{}, + Items: []v1.Pod{ + { + TypeMeta: metav1.TypeMeta{}, + ObjectMeta: metav1.ObjectMeta{ + Name: constants.InferenceServiceContainerName, + }, + Spec: v1.PodSpec{}, + Status: v1.PodStatus{ + ContainerStatuses: []v1.ContainerStatus{ + { + Name: constants.InferenceServiceContainerName, + State: v1.ContainerState{ + Waiting: &v1.ContainerStateWaiting{ + Reason: constants.StateReasonCrashLoopBackOff, + Message: "For testing", + }, + }, + LastTerminationState: v1.ContainerState{ + Terminated: &v1.ContainerStateTerminated{ + Reason: constants.StateReasonCrashLoopBackOff, + Message: "For testing", + ExitCode: 1, + }, + }, + Ready: false, + RestartCount: 0, + Image: "", + ImageID: "", + ContainerID: "", + Started: nil, + }, + }, + }, + }, + }, + }, + rawDeployment: false, + expectedRevisionStates: &ModelRevisionStates{ + 
ActiveModelState: "", + TargetModelState: FailedToLoad, + }, + expectedTransitionStatus: BlockedByFailedLoad, + expectedFailureInfo: &FailureInfo{ + Reason: ModelLoadFailed, + Message: "For testing", + ExitCode: 1, + }, + }, + "storage initializer failed due to an error": { + isvcStatus: &InferenceServiceStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{ + { + Type: "Ready", + Status: v1.ConditionFalse, + }, + }, + }, + Address: &duckv1.Addressable{}, + URL: &apis.URL{}, + Components: map[ComponentType]ComponentStatusSpec{ + PredictorComponent: { + LatestRolledoutRevision: "test-predictor-default-0001", + }, + }, + ModelStatus: ModelStatus{}, + }, + statusSpec: ComponentStatusSpec{ + LatestReadyRevision: "", + LatestCreatedRevision: "", + PreviousRolledoutRevision: "", + LatestRolledoutRevision: "", + Traffic: nil, + URL: nil, + RestURL: nil, + GrpcURL: nil, + Address: nil, + }, + podList: &v1.PodList{ + TypeMeta: metav1.TypeMeta{}, + ListMeta: metav1.ListMeta{}, + Items: []v1.Pod{ + { + TypeMeta: metav1.TypeMeta{}, + ObjectMeta: metav1.ObjectMeta{ + Name: constants.StorageInitializerContainerName, + }, + Spec: v1.PodSpec{}, + Status: v1.PodStatus{ + InitContainerStatuses: []v1.ContainerStatus{ + { + Name: constants.StorageInitializerContainerName, + State: v1.ContainerState{ + Terminated: &v1.ContainerStateTerminated{ + ExitCode: 1, + Reason: constants.StateReasonError, + Message: "For testing", + }, + }, + LastTerminationState: v1.ContainerState{}, + Ready: false, + RestartCount: 0, + Image: "", + ImageID: "", + ContainerID: "", + Started: nil, + }, + }, + }, + }, + }, + }, + rawDeployment: false, + expectedRevisionStates: &ModelRevisionStates{ + ActiveModelState: "", + TargetModelState: FailedToLoad, + }, + expectedTransitionStatus: BlockedByFailedLoad, + expectedFailureInfo: &FailureInfo{ + Reason: ModelLoadFailed, + Message: "For testing", + ExitCode: 1, + }, + }, + "storage initializer failed due to crash loopBackOff": { + isvcStatus: &InferenceServiceStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{ + { + Type: "Ready", + Status: v1.ConditionFalse, + }, + }, + }, + Address: &duckv1.Addressable{}, + URL: &apis.URL{}, + Components: map[ComponentType]ComponentStatusSpec{ + PredictorComponent: { + LatestRolledoutRevision: "test-predictor-default-0001", + }, + }, + ModelStatus: ModelStatus{}, + }, + statusSpec: ComponentStatusSpec{ + LatestReadyRevision: "", + LatestCreatedRevision: "", + PreviousRolledoutRevision: "", + LatestRolledoutRevision: "", + Traffic: nil, + URL: nil, + RestURL: nil, + GrpcURL: nil, + Address: nil, + }, + podList: &v1.PodList{ + TypeMeta: metav1.TypeMeta{}, + ListMeta: metav1.ListMeta{}, + Items: []v1.Pod{ + { + TypeMeta: metav1.TypeMeta{}, + ObjectMeta: metav1.ObjectMeta{ + Name: constants.StorageInitializerContainerName, + }, + Spec: v1.PodSpec{}, + Status: v1.PodStatus{ + InitContainerStatuses: []v1.ContainerStatus{ + { + Name: constants.StorageInitializerContainerName, + State: v1.ContainerState{ + Waiting: &v1.ContainerStateWaiting{ + Reason: constants.StateReasonCrashLoopBackOff, + Message: "For testing", + }, + }, + LastTerminationState: v1.ContainerState{ + Terminated: &v1.ContainerStateTerminated{ + Reason: constants.StateReasonCrashLoopBackOff, + Message: "For testing", + ExitCode: 1, + }, + }, + Ready: false, + RestartCount: 0, + Image: "", + ImageID: "", + ContainerID: "", + Started: nil, + }, + }, + }, + }, + }, + }, + rawDeployment: false, + expectedRevisionStates: &ModelRevisionStates{ + ActiveModelState: "", + 
TargetModelState: FailedToLoad, + }, + expectedTransitionStatus: BlockedByFailedLoad, + expectedFailureInfo: &FailureInfo{ + Reason: ModelLoadFailed, + Message: "For testing", + ExitCode: 1, + }, + }, + "storage initializer in running state": { + isvcStatus: &InferenceServiceStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{ + { + Type: "Ready", + Status: v1.ConditionFalse, + }, + }, + }, + Address: &duckv1.Addressable{}, + URL: &apis.URL{}, + Components: map[ComponentType]ComponentStatusSpec{ + PredictorComponent: { + LatestRolledoutRevision: "test-predictor-default-0001", + }, + }, + ModelStatus: ModelStatus{}, + }, + statusSpec: ComponentStatusSpec{ + LatestReadyRevision: "", + LatestCreatedRevision: "", + PreviousRolledoutRevision: "", + LatestRolledoutRevision: "", + Traffic: nil, + URL: nil, + RestURL: nil, + GrpcURL: nil, + Address: nil, + }, + podList: &v1.PodList{ + TypeMeta: metav1.TypeMeta{}, + ListMeta: metav1.ListMeta{}, + Items: []v1.Pod{ + { + TypeMeta: metav1.TypeMeta{}, + ObjectMeta: metav1.ObjectMeta{ + Name: constants.StorageInitializerContainerName, + }, + Spec: v1.PodSpec{}, + Status: v1.PodStatus{ + InitContainerStatuses: []v1.ContainerStatus{ + { + Name: constants.StorageInitializerContainerName, + State: v1.ContainerState{ + Running: &v1.ContainerStateRunning{ + StartedAt: metav1.Time{}, + }, + }, + LastTerminationState: v1.ContainerState{}, + Ready: false, + RestartCount: 0, + Image: "", + ImageID: "", + ContainerID: "", + Started: proto.Bool(true), + }, + }, + }, + }, + }, + }, + rawDeployment: false, + expectedRevisionStates: &ModelRevisionStates{ + ActiveModelState: "", + TargetModelState: Loading, + }, + expectedTransitionStatus: InProgress, + expectedFailureInfo: nil, + }, + "kserve container is ready": { + isvcStatus: &InferenceServiceStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{ + { + Type: "Ready", + Status: v1.ConditionTrue, + }, + }, + }, + Address: &duckv1.Addressable{}, + URL: &apis.URL{}, + Components: map[ComponentType]ComponentStatusSpec{ + PredictorComponent: { + LatestRolledoutRevision: "test-predictor-default-0001", + }, + }, + ModelStatus: ModelStatus{}, + }, + statusSpec: ComponentStatusSpec{ + LatestReadyRevision: "test-predictor-default-0001", + LatestCreatedRevision: "test-predictor-default-0001", + PreviousRolledoutRevision: "", + LatestRolledoutRevision: "test-predictor-default-0001", + Traffic: nil, + URL: nil, + RestURL: nil, + GrpcURL: nil, + Address: nil, + }, + podList: &v1.PodList{ + TypeMeta: metav1.TypeMeta{}, + ListMeta: metav1.ListMeta{}, + Items: []v1.Pod{ + { + TypeMeta: metav1.TypeMeta{}, + ObjectMeta: metav1.ObjectMeta{ + Name: constants.InferenceServiceContainerName, + }, + Spec: v1.PodSpec{}, + Status: v1.PodStatus{ + ContainerStatuses: []v1.ContainerStatus{ + { + Name: constants.InferenceServiceContainerName, + State: v1.ContainerState{ + Running: &v1.ContainerStateRunning{}, + }, + Ready: true, + RestartCount: 0, + Image: "", + ImageID: "", + ContainerID: "", + Started: proto.Bool(true), + }, + }, + }, + }, + }, + }, + rawDeployment: false, + expectedRevisionStates: &ModelRevisionStates{ + ActiveModelState: Loaded, + TargetModelState: Loaded, + }, + expectedTransitionStatus: UpToDate, + expectedFailureInfo: nil, + }, + "raw deployment is ready": { + isvcStatus: &InferenceServiceStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{ + { + Type: "Ready", + Status: v1.ConditionTrue, + }, + }, + }, + Address: &duckv1.Addressable{}, + URL: &apis.URL{}, + Components: 
map[ComponentType]ComponentStatusSpec{ + PredictorComponent: { + LatestRolledoutRevision: "test-predictor-default-0001", + }, + }, + ModelStatus: ModelStatus{}, + }, + statusSpec: ComponentStatusSpec{ + LatestReadyRevision: "test-predictor-default-0001", + LatestCreatedRevision: "test-predictor-default-0001", + PreviousRolledoutRevision: "", + LatestRolledoutRevision: "test-predictor-default-0001", + Traffic: nil, + URL: nil, + RestURL: nil, + GrpcURL: nil, + Address: nil, + }, + podList: &v1.PodList{ + TypeMeta: metav1.TypeMeta{}, + ListMeta: metav1.ListMeta{}, + Items: []v1.Pod{ + { + TypeMeta: metav1.TypeMeta{}, + ObjectMeta: metav1.ObjectMeta{ + Name: constants.InferenceServiceContainerName, + }, + Spec: v1.PodSpec{}, + Status: v1.PodStatus{ + ContainerStatuses: []v1.ContainerStatus{ + { + Name: constants.InferenceServiceContainerName, + State: v1.ContainerState{ + Running: &v1.ContainerStateRunning{}, + }, + Ready: true, + RestartCount: 0, + Image: "", + ImageID: "", + ContainerID: "", + Started: proto.Bool(true), + }, + }, + }, + }, + }, + }, + rawDeployment: true, + expectedRevisionStates: &ModelRevisionStates{ + ActiveModelState: Loaded, + TargetModelState: Loaded, + }, + expectedTransitionStatus: UpToDate, + expectedFailureInfo: nil, + }, + "skip containers other than kserve": { + isvcStatus: &InferenceServiceStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{ + { + Type: "Ready", + Status: v1.ConditionFalse, + }, + }, + }, + Address: &duckv1.Addressable{}, + URL: &apis.URL{}, + Components: map[ComponentType]ComponentStatusSpec{ + PredictorComponent: { + LatestRolledoutRevision: "test-predictor-default-0001", + }, + }, + ModelStatus: ModelStatus{}, + }, + statusSpec: ComponentStatusSpec{ + LatestReadyRevision: "", + LatestCreatedRevision: "", + PreviousRolledoutRevision: "", + LatestRolledoutRevision: "", + Traffic: nil, + URL: nil, + RestURL: nil, + GrpcURL: nil, + Address: nil, + }, + podList: &v1.PodList{ + TypeMeta: metav1.TypeMeta{}, + ListMeta: metav1.ListMeta{}, + Items: []v1.Pod{ + { + TypeMeta: metav1.TypeMeta{}, + ObjectMeta: metav1.ObjectMeta{ + Name: "test-container", + }, + Spec: v1.PodSpec{}, + Status: v1.PodStatus{ + ContainerStatuses: []v1.ContainerStatus{ + { + Name: "test-container", + State: v1.ContainerState{}, + LastTerminationState: v1.ContainerState{}, + Ready: false, + RestartCount: 0, + Image: "", + ImageID: "", + ContainerID: "", + Started: nil, + }, + }, + }, + }, + }, + }, + rawDeployment: false, + expectedRevisionStates: nil, + expectedTransitionStatus: "", + expectedFailureInfo: nil, + }, + "skip initcontainers other than storage initializer": { + isvcStatus: &InferenceServiceStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{ + { + Type: "Ready", + Status: v1.ConditionFalse, + }, + }, + }, + Address: &duckv1.Addressable{}, + URL: &apis.URL{}, + Components: map[ComponentType]ComponentStatusSpec{ + PredictorComponent: { + LatestRolledoutRevision: "test-predictor-default-0001", + }, + }, + ModelStatus: ModelStatus{}, + }, + statusSpec: ComponentStatusSpec{ + LatestReadyRevision: "", + LatestCreatedRevision: "", + PreviousRolledoutRevision: "", + LatestRolledoutRevision: "", + Traffic: nil, + URL: nil, + RestURL: nil, + GrpcURL: nil, + Address: nil, + }, + podList: &v1.PodList{ + TypeMeta: metav1.TypeMeta{}, + ListMeta: metav1.ListMeta{}, + Items: []v1.Pod{ + { + TypeMeta: metav1.TypeMeta{}, + ObjectMeta: metav1.ObjectMeta{ + Name: "test-container", + }, + Spec: v1.PodSpec{}, + Status: v1.PodStatus{ + InitContainerStatuses: 
[]v1.ContainerStatus{ + { + Name: "test-container", + State: v1.ContainerState{}, + LastTerminationState: v1.ContainerState{}, + Ready: false, + RestartCount: 0, + Image: "", + ImageID: "", + ContainerID: "", + Started: nil, + }, + }, + }, + }, + }, + }, + rawDeployment: false, + expectedRevisionStates: nil, + expectedTransitionStatus: "", + expectedFailureInfo: nil, + }, + } + + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + scenario.isvcStatus.PropagateModelStatus(scenario.statusSpec, scenario.podList, scenario.rawDeployment) + + g.Expect(scenario.isvcStatus.ModelStatus.ModelRevisionStates).To(gomega.Equal(scenario.expectedRevisionStates)) + g.Expect(scenario.isvcStatus.ModelStatus.TransitionStatus).To(gomega.Equal(scenario.expectedTransitionStatus)) + g.Expect(scenario.isvcStatus.ModelStatus.LastFailureInfo).To(gomega.Equal(scenario.expectedFailureInfo)) + }) + } +} + +func TestInferenceServiceStatus_UpdateModelRevisionStates(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + scenarios := map[string]struct { + isvcStatus *InferenceServiceStatus + transitionStatus TransitionStatus + failureInfo *FailureInfo + expected ModelStatus + }{ + "simple": { + isvcStatus: &InferenceServiceStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{ + { + Type: "Ready", + Status: v1.ConditionFalse, + }, + }, + }, + Address: &duckv1.Addressable{}, + URL: &apis.URL{}, + Components: map[ComponentType]ComponentStatusSpec{ + PredictorComponent: { + LatestRolledoutRevision: "test-predictor-default-0001", + }, + }, + ModelStatus: ModelStatus{}, + }, + transitionStatus: InProgress, + failureInfo: nil, + expected: ModelStatus{ + TransitionStatus: InProgress, + LastFailureInfo: nil, + }, + }, + "invalid spec with nil modelRevisionStates": { + isvcStatus: &InferenceServiceStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{ + { + Type: "Ready", + Status: v1.ConditionFalse, + }, + }, + }, + Address: &duckv1.Addressable{}, + URL: &apis.URL{}, + Components: map[ComponentType]ComponentStatusSpec{ + PredictorComponent: { + LatestRolledoutRevision: "test-predictor-default-0001", + }, + }, + ModelStatus: ModelStatus{}, + }, + transitionStatus: InvalidSpec, + failureInfo: &FailureInfo{ + Reason: ModelLoadFailed, + Message: "For testing", + ExitCode: 1, + }, + expected: ModelStatus{ + TransitionStatus: InvalidSpec, + ModelRevisionStates: &ModelRevisionStates{TargetModelState: FailedToLoad}, + LastFailureInfo: &FailureInfo{ + Reason: ModelLoadFailed, + Message: "For testing", + ExitCode: 1, + }, + }, + }, + "invalid spec with modelRevisionStates": { + isvcStatus: &InferenceServiceStatus{ + Status: duckv1.Status{ + Conditions: duckv1.Conditions{ + { + Type: "Ready", + Status: v1.ConditionFalse, + }, + }, + }, + Address: &duckv1.Addressable{}, + URL: &apis.URL{}, + Components: map[ComponentType]ComponentStatusSpec{ + PredictorComponent: { + LatestRolledoutRevision: "test-predictor-default-0001", + }, + }, + ModelStatus: ModelStatus{ + ModelRevisionStates: &ModelRevisionStates{TargetModelState: Loading}, + }, + }, + transitionStatus: InvalidSpec, + failureInfo: &FailureInfo{ + Reason: ModelLoadFailed, + Message: "For testing", + ExitCode: 1, + }, + expected: ModelStatus{ + TransitionStatus: InvalidSpec, + ModelRevisionStates: &ModelRevisionStates{TargetModelState: FailedToLoad}, + LastFailureInfo: &FailureInfo{ + Reason: ModelLoadFailed, + Message: "For testing", + ExitCode: 1, + }, + }, + }, + } + + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { 
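+ // UpdateModelTransitionStatus should copy the transition status and failure info into the model status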
+ scenario.isvcStatus.UpdateModelTransitionStatus(scenario.transitionStatus, scenario.failureInfo) + + g.Expect(scenario.isvcStatus.ModelStatus).To(gomega.Equal(scenario.expected)) + }) + } +} diff --git a/pkg/apis/serving/v1beta1/predictor_custom_test.go b/pkg/apis/serving/v1beta1/predictor_custom_test.go index aef93c0b83c..b536ef7e5d1 100644 --- a/pkg/apis/serving/v1beta1/predictor_custom_test.go +++ b/pkg/apis/serving/v1beta1/predictor_custom_test.go @@ -25,6 +25,7 @@ import ( "github.com/onsi/gomega/types" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) func TestCustomPredictorValidation(t *testing.T) { @@ -201,3 +202,226 @@ func TestCustomPredictorDefaulter(t *testing.T) { }) } } + +func TestCreateCustomPredictorContainer(t *testing.T) { + + var requestedResource = v1.ResourceRequirements{ + Limits: v1.ResourceList{ + "cpu": resource.Quantity{ + Format: "100", + }, + "memory": resource.MustParse("1Gi"), + }, + Requests: v1.ResourceList{ + "cpu": resource.Quantity{ + Format: "90", + }, + "memory": resource.MustParse("1Gi"), + }, + } + var config = InferenceServicesConfig{} + g := gomega.NewGomegaWithT(t) + scenarios := map[string]struct { + isvc InferenceService + expectedContainerSpec *v1.Container + }{ + "ContainerSpecWithCustomImage": { + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "custom-predictor", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + PodSpec: PodSpec{ + Containers: []v1.Container{ + { + Image: "custom-predictor:0.1.0", + Args: []string{ + "--model_name", + "someName", + "--http_port", + "8080", + }, + Env: []v1.EnvVar{ + { + Name: "STORAGE_URI", + Value: "hdfs://modelzoo", + }, + }, + Resources: requestedResource, + }, + }, + }, + }, + }, + }, + expectedContainerSpec: &v1.Container{ + Image: "custom-predictor:0.1.0", + Name: constants.InferenceServiceContainerName, + Resources: requestedResource, + Args: []string{ + "--model_name", + "someName", + "--http_port", + "8080", + }, + Env: []v1.EnvVar{ + { + Name: "STORAGE_URI", + Value: "hdfs://modelzoo", + }, + }, + }, + }, + } + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + predictor := scenario.isvc.Spec.Predictor.GetImplementation() + predictor.Default(&config) + res := predictor.GetContainer(metav1.ObjectMeta{Name: "someName", Namespace: "default"}, &scenario.isvc.Spec.Predictor.ComponentExtensionSpec, &config) + if !g.Expect(res).To(gomega.Equal(scenario.expectedContainerSpec)) { + t.Errorf("got %q, want %q", res, scenario.expectedContainerSpec) + } + }) + } +} + +func TestCustomPredictorIsMMS(t *testing.T) { + g := gomega.NewGomegaWithT(t) + config := InferenceServicesConfig{} + + defaultResource = v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("1"), + v1.ResourceMemory: resource.MustParse("2Gi"), + } + + mmsCase := false + scenarios := map[string]struct { + spec PredictorSpec + expected bool + }{ + "DefaultResources": { + spec: PredictorSpec{ + PodSpec: PodSpec{ + Containers: []v1.Container{ + { + Env: []v1.EnvVar{ + { + Name: "STORAGE_URI", + Value: "hdfs://modelzoo", + }, + }, + }, + }, + }, + }, + expected: mmsCase, + }, + } + + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + customPredictor := NewCustomPredictor(&scenario.spec.PodSpec) + res := customPredictor.IsMMS(&config) + if !g.Expect(res).To(gomega.Equal(scenario.expected)) { + t.Errorf("got %t, want %t", res, scenario.expected) + } + }) + } +} + +func 
TestCustomPredictorIsFrameworkSupported(t *testing.T) { + g := gomega.NewGomegaWithT(t) + framework := "framework" + config := InferenceServicesConfig{} + + defaultResource = v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("1"), + v1.ResourceMemory: resource.MustParse("2Gi"), + } + + scenarios := map[string]struct { + spec PredictorSpec + expected bool + }{ + "DefaultResources": { + spec: PredictorSpec{ + PodSpec: PodSpec{ + Containers: []v1.Container{ + { + Env: []v1.EnvVar{ + { + Name: "STORAGE_URI", + Value: "hdfs://modelzoo", + }, + }, + }, + }, + }, + }, + expected: true, + }, + } + + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + customPredictor := NewCustomPredictor(&scenario.spec.PodSpec) + res := customPredictor.IsFrameworkSupported(framework, &config) + if !g.Expect(res).To(gomega.Equal(scenario.expected)) { + t.Errorf("got %t, want %t", res, scenario.expected) + } + }) + } +} + +func TestCustomPredictorGetProtocol(t *testing.T) { + g := gomega.NewGomegaWithT(t) + scenarios := map[string]struct { + spec PredictorSpec + matcher types.GomegaMatcher + }{ + "Default protocol": { + spec: PredictorSpec{ + PodSpec: PodSpec{ + Containers: []v1.Container{ + { + Env: []v1.EnvVar{ + { + Name: "STORAGE_URI", + Value: "s3://modelzoo", + }, + }, + }, + }, + }, + }, + matcher: gomega.Equal(constants.ProtocolV1), + }, + "protocol v2": { + spec: PredictorSpec{ + PodSpec: PodSpec{ + Containers: []v1.Container{ + { + Env: []v1.EnvVar{ + { + Name: "STORAGE_URI", + Value: "s3://modelzoo", + }, + { + Name: constants.CustomSpecProtocolEnvVarKey, + Value: string(constants.ProtocolV2), + }, + }, + }, + }, + }, + }, + matcher: gomega.Equal(constants.ProtocolV2), + }, + } + for _, scenario := range scenarios { + customPredictor := NewCustomPredictor(&scenario.spec.PodSpec) + protocol := customPredictor.GetProtocol() + g.Expect(protocol).To(scenario.matcher) + } +} diff --git a/pkg/apis/serving/v1beta1/predictor_lightgbm_test.go b/pkg/apis/serving/v1beta1/predictor_lightgbm_test.go index a94ea529a03..0ab6ed3e439 100644 --- a/pkg/apis/serving/v1beta1/predictor_lightgbm_test.go +++ b/pkg/apis/serving/v1beta1/predictor_lightgbm_test.go @@ -20,6 +20,8 @@ import ( "testing" "github.com/golang/protobuf/proto" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "github.com/kserve/kserve/pkg/constants" "github.com/onsi/gomega" "github.com/onsi/gomega/types" @@ -120,3 +122,168 @@ func TestLightGBMDefaulter(t *testing.T) { }) } } + +func TestCreateLightGBMModelServingContainer(t *testing.T) { + + var requestedResource = v1.ResourceRequirements{ + Limits: v1.ResourceList{ + "cpu": resource.MustParse("100m"), + }, + Requests: v1.ResourceList{ + "cpu": resource.MustParse("90m"), + }, + } + var config = InferenceServicesConfig{} + g := gomega.NewGomegaWithT(t) + scenarios := map[string]struct { + isvc InferenceService + expectedContainerSpec *v1.Container + }{ + "ContainerSpecWithDefaultImage": { + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "lightgbm", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + LightGBM: &LightGBMSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://someUri"), + RuntimeVersion: proto.String("0.1.0"), + Container: v1.Container{ + Resources: requestedResource, + }, + }, + }, + }, + }, + }, + expectedContainerSpec: &v1.Container{ + Name: constants.InferenceServiceContainerName, + Resources: requestedResource, + }, + }, + "ContainerSpecWithCustomImage": { + isvc: InferenceService{ + ObjectMeta: 
metav1.ObjectMeta{ + Name: "lightgbm", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + LightGBM: &LightGBMSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://someUri"), + Container: v1.Container{ + Image: "customImage:0.1.0", + Resources: requestedResource, + }, + }, + }, + }, + }, + }, + expectedContainerSpec: &v1.Container{ + Image: "customImage:0.1.0", + Name: constants.InferenceServiceContainerName, + Resources: requestedResource, + }, + }, + "ContainerSpecWithContainerConcurrency": { + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "lightgbm", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + ComponentExtensionSpec: ComponentExtensionSpec{ + ContainerConcurrency: proto.Int64(1), + }, + LightGBM: &LightGBMSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://someUri"), + RuntimeVersion: proto.String("0.1.0"), + Container: v1.Container{ + Resources: requestedResource, + }, + }, + }, + }, + }, + }, + expectedContainerSpec: &v1.Container{ + Name: constants.InferenceServiceContainerName, + Resources: requestedResource, + }, + }, + "ContainerSpecWithWorker": { + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "lightgbm", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + ComponentExtensionSpec: ComponentExtensionSpec{ + ContainerConcurrency: proto.Int64(2), + }, + LightGBM: &LightGBMSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://someUri"), + RuntimeVersion: proto.String("0.1.0"), + Container: v1.Container{ + Resources: requestedResource, + Args: []string{ + constants.ArgumentWorkers + "=1", + }, + }, + }, + }, + }, + }, + }, + expectedContainerSpec: &v1.Container{ + Name: constants.InferenceServiceContainerName, + Resources: requestedResource, + Args: []string{ + "--workers=1", + }, + }, + }, + } + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + predictor := scenario.isvc.Spec.Predictor.GetImplementation() + predictor.Default(&config) + res := predictor.GetContainer(metav1.ObjectMeta{Name: "someName"}, &scenario.isvc.Spec.Predictor.ComponentExtensionSpec, &config) + if !g.Expect(res).To(gomega.Equal(scenario.expectedContainerSpec)) { + t.Errorf("got %q, want %q", res, scenario.expectedContainerSpec) + } + }) + } +} + +func TestLightGBMGetProtocol(t *testing.T) { + g := gomega.NewGomegaWithT(t) + config := InferenceServicesConfig{} + scenarios := map[string]struct { + spec PredictorSpec + matcher types.GomegaMatcher + }{ + "DefaultProtocol": { + spec: PredictorSpec{ + LightGBM: &LightGBMSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("s3://modelzoo"), + }, + }, + }, + matcher: gomega.Equal(constants.ProtocolV1), + }, + } + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + scenario.spec.LightGBM.Default(&config) + protocol := scenario.spec.LightGBM.GetProtocol() + g.Expect(protocol).To(scenario.matcher) + }) + } +} diff --git a/pkg/apis/serving/v1beta1/predictor_model_test.go b/pkg/apis/serving/v1beta1/predictor_model_test.go index 7caf733729e..b13367a9701 100644 --- a/pkg/apis/serving/v1beta1/predictor_model_test.go +++ b/pkg/apis/serving/v1beta1/predictor_model_test.go @@ -23,6 +23,7 @@ import ( "github.com/kserve/kserve/pkg/apis/serving/v1alpha1" "github.com/kserve/kserve/pkg/constants" "github.com/onsi/gomega" + "github.com/onsi/gomega/types" v1 "k8s.io/api/core/v1" metav1 
"k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" @@ -297,7 +298,10 @@ func TestGetSupportingRuntimes(t *testing.T) { } s := runtime.NewScheme() - v1alpha1.AddToScheme(s) + err := v1alpha1.AddToScheme(s) + if err != nil { + t.Errorf("unable to add scheme : %v", err) + } mockClient := fake.NewClientBuilder().WithLists(runtimes, clusterRuntimes).WithScheme(s).Build() for name, scenario := range scenarios { @@ -310,3 +314,94 @@ func TestGetSupportingRuntimes(t *testing.T) { } } + +func TestModelPredictorGetContainer(t *testing.T) { + g := gomega.NewGomegaWithT(t) + var storageUri = "s3://test/model" + isvcConfig := &InferenceServicesConfig{} + objectMeta := metav1.ObjectMeta{ + Name: "foo", + Namespace: "default", + } + componentSpec := &ComponentExtensionSpec{ + MinReplicas: GetIntReference(3), + MaxReplicas: 2, + } + scenarios := map[string]struct { + spec *ModelSpec + expected v1.Container + }{ + "ContainerSpecified": { + spec: &ModelSpec{ + ModelFormat: ModelFormat{ + Name: "tensorflow", + }, + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: &storageUri, + Container: v1.Container{ + Name: "foo", + Env: []v1.EnvVar{ + { + Name: "STORAGE_URI", + Value: storageUri, + }, + }, + }, + }, + }, + expected: v1.Container{ + Name: "foo", + Env: []v1.EnvVar{ + { + Name: "STORAGE_URI", + Value: storageUri, + }, + }, + }, + }, + } + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + container := scenario.spec.GetContainer(objectMeta, componentSpec, isvcConfig) + g.Expect(*container).To(gomega.Equal(scenario.expected)) + }) + } +} + +func TestModelPredictorGetProtocol(t *testing.T) { + g := gomega.NewGomegaWithT(t) + scenarios := map[string]struct { + spec *ModelSpec + matcher types.GomegaMatcher + }{ + "DefaultProtocol": { + spec: &ModelSpec{ + ModelFormat: ModelFormat{ + Name: "tensorflow", + }, + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("s3://test/model"), + }, + }, + matcher: gomega.Equal(constants.ProtocolV1), + }, + "ProtocolV2Specified": { + spec: &ModelSpec{ + ModelFormat: ModelFormat{ + Name: "tensorflow", + }, + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("s3://test/model"), + ProtocolVersion: (*constants.InferenceServiceProtocol)(proto.String(string(constants.ProtocolV2))), + }, + }, + matcher: gomega.Equal(constants.ProtocolV2), + }, + } + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + protocol := scenario.spec.GetProtocol() + g.Expect(protocol).To(scenario.matcher) + }) + } +} diff --git a/pkg/apis/serving/v1beta1/predictor_onnxruntime_test.go b/pkg/apis/serving/v1beta1/predictor_onnxruntime_test.go index b17691dcb3b..dd63ae54696 100644 --- a/pkg/apis/serving/v1beta1/predictor_onnxruntime_test.go +++ b/pkg/apis/serving/v1beta1/predictor_onnxruntime_test.go @@ -17,6 +17,7 @@ limitations under the License. 
package v1beta1 import ( + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "testing" "github.com/golang/protobuf/proto" @@ -140,3 +141,75 @@ func TestONNXRuntimeDefaulter(t *testing.T) { }) } } + +func TestONNXRuntimeSpec_GetContainer(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + metadata := metav1.ObjectMeta{Name: constants.InferenceServiceContainerName} + scenarios := map[string]struct { + spec PredictorSpec + }{ + "simple": { + spec: PredictorSpec{ + ONNX: &ONNXRuntimeSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("s3://modelzoo"), + Container: v1.Container{ + Name: constants.InferenceServiceContainerName, + Image: "image:0.1", + Args: nil, + Env: nil, + Resources: v1.ResourceRequirements{}, + }, + }, + }, + ComponentExtensionSpec: ComponentExtensionSpec{}, + }, + }, + } + + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + res := scenario.spec.ONNX.GetContainer(metadata, &scenario.spec.ComponentExtensionSpec, nil) + if !g.Expect(res).To(gomega.Equal(&scenario.spec.ONNX.Container)) { + t.Errorf("got %v, want %v", res, scenario.spec.ONNX.Container) + } + }) + } +} + +func TestONNXRuntimeSpec_GetProtocol(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + scenarios := map[string]struct { + spec PredictorSpec + expected constants.InferenceServiceProtocol + }{ + "default": { + spec: PredictorSpec{ + ONNX: &ONNXRuntimeSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("s3://modelzoo"), + Container: v1.Container{ + Image: "image:0.1", + Args: nil, + Env: nil, + Resources: v1.ResourceRequirements{}, + }, + }, + }, + ComponentExtensionSpec: ComponentExtensionSpec{}, + }, + expected: constants.ProtocolV1, + }, + } + + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + res := scenario.spec.ONNX.GetProtocol() + if !g.Expect(res).To(gomega.Equal(scenario.expected)) { + t.Errorf("got %v, want %v", res, scenario.expected) + } + }) + } +} diff --git a/pkg/apis/serving/v1beta1/predictor_paddle_test.go b/pkg/apis/serving/v1beta1/predictor_paddle_test.go index 8ddc4c765b2..a0ea58dd79b 100644 --- a/pkg/apis/serving/v1beta1/predictor_paddle_test.go +++ b/pkg/apis/serving/v1beta1/predictor_paddle_test.go @@ -17,6 +17,7 @@ limitations under the License.
package v1beta1 import ( + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "testing" "github.com/golang/protobuf/proto" @@ -113,3 +114,75 @@ func TestPaddleDefaulter(t *testing.T) { }) } } + +func TestPaddleServerSpec_GetContainer(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + metadata := metav1.ObjectMeta{Name: constants.InferenceServiceContainerName} + scenarios := map[string]struct { + spec PredictorSpec + }{ + "simple": { + spec: PredictorSpec{ + Paddle: &PaddleServerSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("s3://modelzoo"), + Container: v1.Container{ + Name: constants.InferenceServiceContainerName, + Image: "image:0.1", + Args: nil, + Env: nil, + Resources: v1.ResourceRequirements{}, + }, + }, + }, + ComponentExtensionSpec: ComponentExtensionSpec{}, + }, + }, + } + + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + res := scenario.spec.Paddle.GetContainer(metadata, &scenario.spec.ComponentExtensionSpec, nil) + if !g.Expect(res).To(gomega.Equal(&scenario.spec.Paddle.Container)) { + t.Errorf("got %v, want %v", res, scenario.spec.Paddle.Container) + } + }) + } +} + +func TestPaddleServerSpec_GetProtocol(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + scenarios := map[string]struct { + spec PredictorSpec + expected constants.InferenceServiceProtocol + }{ + "default": { + spec: PredictorSpec{ + Paddle: &PaddleServerSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("s3://modelzoo"), + Container: v1.Container{ + Image: "image:0.1", + Args: nil, + Env: nil, + Resources: v1.ResourceRequirements{}, + }, + }, + }, + ComponentExtensionSpec: ComponentExtensionSpec{}, + }, + expected: constants.ProtocolV1, + }, + } + + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + res := scenario.spec.Paddle.GetProtocol() + if !g.Expect(res).To(gomega.Equal(scenario.expected)) { + t.Errorf("got %v, want %v", res, scenario.expected) + } + }) + } +} diff --git a/pkg/apis/serving/v1beta1/predictor_pmml_test.go b/pkg/apis/serving/v1beta1/predictor_pmml_test.go index c77e05ca523..f96fd437eea 100644 --- a/pkg/apis/serving/v1beta1/predictor_pmml_test.go +++ b/pkg/apis/serving/v1beta1/predictor_pmml_test.go @@ -17,6 +17,7 @@ limitations under the License.
package v1beta1 import ( + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "testing" "github.com/golang/protobuf/proto" @@ -121,3 +122,75 @@ func TestPMMLDefaulter(t *testing.T) { }) } } + +func TestPMMLSpec_GetProtocol(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + scenarios := map[string]struct { + spec PredictorSpec + expected constants.InferenceServiceProtocol + }{ + "default": { + spec: PredictorSpec{ + PMML: &PMMLSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("s3://modelzoo"), + Container: v1.Container{ + Image: "image:0.1", + Args: nil, + Env: nil, + Resources: v1.ResourceRequirements{}, + }, + }, + }, + ComponentExtensionSpec: ComponentExtensionSpec{}, + }, + expected: constants.ProtocolV1, + }, + } + + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + res := scenario.spec.PMML.GetProtocol() + if !g.Expect(res).To(gomega.Equal(scenario.expected)) { + t.Errorf("got %v, want %v", res, scenario.expected) + } + }) + } +} + +func TestPMMLSpec_GetContainer(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + metadata := metav1.ObjectMeta{Name: constants.InferenceServiceContainerName} + scenarios := map[string]struct { + spec PredictorSpec + }{ + "simple": { + spec: PredictorSpec{ + PMML: &PMMLSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("s3://modelzoo"), + Container: v1.Container{ + Name: constants.InferenceServiceContainerName, + Image: "image:0.1", + Args: nil, + Env: nil, + Resources: v1.ResourceRequirements{}, + }, + }, + }, + ComponentExtensionSpec: ComponentExtensionSpec{}, + }, + }, + } + + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + res := scenario.spec.PMML.GetContainer(metadata, &scenario.spec.ComponentExtensionSpec, nil) + if !g.Expect(res).To(gomega.Equal(&scenario.spec.PMML.Container)) { + t.Errorf("got %v, want %v", res, scenario.spec.PMML.Container) + } + }) + } +} diff --git a/pkg/apis/serving/v1beta1/predictor_sklearn_test.go b/pkg/apis/serving/v1beta1/predictor_sklearn_test.go index b26ae6a367e..5d593731b7a 100644 --- a/pkg/apis/serving/v1beta1/predictor_sklearn_test.go +++ b/pkg/apis/serving/v1beta1/predictor_sklearn_test.go @@ -17,9 +17,12 @@ limitations under the License.
package v1beta1 import ( + "strconv" "testing" "github.com/golang/protobuf/proto" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "github.com/kserve/kserve/pkg/constants" "github.com/onsi/gomega" "github.com/onsi/gomega/types" @@ -148,3 +151,431 @@ func TestSKLearnDefaulter(t *testing.T) { }) } } + +func TestCreateSKLearnModelServingContainerV1(t *testing.T) { + var requestedResource = v1.ResourceRequirements{ + Limits: v1.ResourceList{ + "cpu": resource.Quantity{ + Format: "100", + }, + }, + Requests: v1.ResourceList{ + "cpu": resource.Quantity{ + Format: "90", + }, + }, + } + var config = InferenceServicesConfig{} + g := gomega.NewGomegaWithT(t) + scenarios := map[string]struct { + isvc InferenceService + expectedContainerSpec *v1.Container + }{ + "ContainerSpecWithoutRuntime": { + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "sklearn", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + SKLearn: &SKLearnSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://someUri"), + Container: v1.Container{ + Resources: requestedResource, + }, + }, + }, + }, + }, + }, + expectedContainerSpec: &v1.Container{ + Name: constants.InferenceServiceContainerName, + Resources: requestedResource, + }, + }, + "ContainerSpecWithDefaultImage": { + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "sklearn", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + SKLearn: &SKLearnSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://someUri"), + RuntimeVersion: proto.String("0.1.0"), + Container: v1.Container{ + Resources: requestedResource, + }, + }, + }, + }, + }, + }, + expectedContainerSpec: &v1.Container{ + Name: constants.InferenceServiceContainerName, + Resources: requestedResource, + }, + }, + "ContainerSpecWithCustomImage": { + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "sklearn", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + SKLearn: &SKLearnSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://someUri"), + Container: v1.Container{ + Image: "customImage:0.1.0", + Resources: requestedResource, + }, + }, + }, + }, + }, + }, + expectedContainerSpec: &v1.Container{ + Image: "customImage:0.1.0", + Name: constants.InferenceServiceContainerName, + Resources: requestedResource, + }, + }, + "ContainerSpecWithContainerConcurrency": { + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "sklearn", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + ComponentExtensionSpec: ComponentExtensionSpec{ + ContainerConcurrency: proto.Int64(1), + }, + SKLearn: &SKLearnSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://someUri"), + RuntimeVersion: proto.String("0.1.0"), + Container: v1.Container{ + Resources: requestedResource, + }, + }, + }, + }, + }, + }, + expectedContainerSpec: &v1.Container{ + Name: constants.InferenceServiceContainerName, + Resources: requestedResource, + }, + }, + "ContainerSpecWithWorkerArg": { + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "sklearn", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + ComponentExtensionSpec: ComponentExtensionSpec{ + ContainerConcurrency: proto.Int64(4), + }, + SKLearn: &SKLearnSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://someUri"), + RuntimeVersion: proto.String("0.1.0"), + Container: v1.Container{ + Resources: 
requestedResource, + Args: []string{ + "--workers=1", + }, + }, + }, + }, + }, + }, + }, + expectedContainerSpec: &v1.Container{ + Name: constants.InferenceServiceContainerName, + Resources: requestedResource, + Args: []string{ + "--workers=1", + }, + }, + }, + } + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + predictor := scenario.isvc.Spec.Predictor.GetImplementation() + predictor.Default(&config) + res := predictor.GetContainer(metav1.ObjectMeta{Name: "someName"}, &scenario.isvc.Spec.Predictor.ComponentExtensionSpec, &config) + if !g.Expect(res).To(gomega.Equal(scenario.expectedContainerSpec)) { + t.Errorf("got %q, want %q", res, scenario.expectedContainerSpec) + } + }) + } +} + +func TestCreateSKLearnModelServingContainerV2(t *testing.T) { + protocolV2 := constants.ProtocolV2 + + var requestedResource = v1.ResourceRequirements{ + Limits: v1.ResourceList{ + "cpu": resource.Quantity{ + Format: "100", + }, + }, + Requests: v1.ResourceList{ + "cpu": resource.Quantity{ + Format: "90", + }, + }, + } + var config = InferenceServicesConfig{} + g := gomega.NewGomegaWithT(t) + scenarios := map[string]struct { + isvc InferenceService + expectedContainerSpec *v1.Container + }{ + "ContainerSpecWithDefaultImage": { + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "sklearn", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + SKLearn: &SKLearnSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://someUri"), + RuntimeVersion: proto.String("0.1.0"), + ProtocolVersion: &protocolV2, + Container: v1.Container{ + Resources: requestedResource, + }, + }, + }, + }, + }, + }, + expectedContainerSpec: &v1.Container{ + Name: constants.InferenceServiceContainerName, + Resources: requestedResource, + }, + }, + "ContainerSpecWithCustomImage": { + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "sklearn", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + SKLearn: &SKLearnSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://someUri"), + ProtocolVersion: &protocolV2, + Container: v1.Container{ + Image: "customImage:0.1.0", + Resources: requestedResource, + }, + }, + }, + }, + }, + }, + expectedContainerSpec: &v1.Container{ + Image: "customImage:0.1.0", + Name: constants.InferenceServiceContainerName, + Resources: requestedResource, + }, + }, + "ContainerSpecWithoutStorageURI": { + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "sklearn", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + SKLearn: &SKLearnSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + ProtocolVersion: &protocolV2, + Container: v1.Container{ + Resources: requestedResource, + }, + }, + }, + }, + }, + }, + expectedContainerSpec: &v1.Container{ + Name: constants.InferenceServiceContainerName, + Resources: requestedResource, + }, + }, + } + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + predictor := scenario.isvc.Spec.Predictor.GetImplementation() + predictor.Default(&config) + res := predictor.GetContainer(scenario.isvc.ObjectMeta, &scenario.isvc.Spec.Predictor.ComponentExtensionSpec, &config) + if !g.Expect(res).To(gomega.Equal(scenario.expectedContainerSpec)) { + t.Errorf("got %q, want %q", res, scenario.expectedContainerSpec) + } + }) + } +} + +func TestSKLearnGetProtocol(t *testing.T) { + g := gomega.NewGomegaWithT(t) + scenarios := map[string]struct { + spec PredictorSpec + matcher types.GomegaMatcher + }{ + 
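+ // the protocol version is expected to default to v1 when none is specified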
"DefaultProtocol": { + spec: PredictorSpec{ + SKLearn: &SKLearnSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{}, + }, + }, + matcher: gomega.Equal(constants.ProtocolV1), + }, + "ProtocolV2": { + spec: PredictorSpec{ + SKLearn: &SKLearnSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + ProtocolVersion: (*constants.InferenceServiceProtocol)(proto.String(string(constants.ProtocolV2))), + }, + }, + }, + matcher: gomega.Equal(constants.ProtocolV2), + }, + } + for _, scenario := range scenarios { + protocol := scenario.spec.SKLearn.GetProtocol() + g.Expect(protocol).To(scenario.matcher) + } +} + +func TestSKLearnSpec_GetDefaultsV2(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + metadata := metav1.ObjectMeta{ + Name: "test", + } + scenarios := map[string]struct { + spec PredictorSpec + matcher types.GomegaMatcher + }{ + "storage uri is nil": { + spec: PredictorSpec{ + SKLearn: &SKLearnSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{}, + }, + }, + matcher: gomega.Equal([]v1.EnvVar{ + { + Name: constants.MLServerModelImplementationEnv, + Value: constants.MLServerSKLearnImplementation, + }, + }), + }, + "storage uri is not nil": { + spec: PredictorSpec{ + SKLearn: &SKLearnSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://kserve/model"), + }, + }, + }, + matcher: gomega.Equal([]v1.EnvVar{ + { + Name: constants.MLServerModelImplementationEnv, + Value: constants.MLServerSKLearnImplementation, + }, + { + Name: constants.MLServerModelNameEnv, + Value: metadata.Name, + }, + { + Name: constants.MLServerModelURIEnv, + Value: constants.DefaultModelLocalMountPath, + }, + }), + }, + } + + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + res := scenario.spec.SKLearn.getDefaultsV2(metadata) + if !g.Expect(res).To(scenario.matcher) { + t.Errorf("got %q, want %q", res, scenario.matcher) + } + }) + } +} + +func TestSKLearnSpec_GetEnvVarsV2(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + scenarios := map[string]struct { + spec PredictorSpec + matcher types.GomegaMatcher + }{ + "storage uri is nil": { + spec: PredictorSpec{ + SKLearn: &SKLearnSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{}, + }, + }, + matcher: gomega.Equal([]v1.EnvVar{ + { + Name: constants.MLServerHTTPPortEnv, + Value: strconv.Itoa(int(constants.MLServerISRestPort)), + }, + { + Name: constants.MLServerGRPCPortEnv, + Value: strconv.Itoa(int(constants.MLServerISGRPCPort)), + }, + { + Name: constants.MLServerModelsDirEnv, + Value: constants.DefaultModelLocalMountPath, + }, + { + Name: constants.MLServerLoadModelsStartupEnv, + Value: strconv.FormatBool(false), + }, + }), + }, + "storage uri is not nil": { + spec: PredictorSpec{ + SKLearn: &SKLearnSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://kserve/model"), + }, + }, + }, + matcher: gomega.Equal([]v1.EnvVar{ + { + Name: constants.MLServerHTTPPortEnv, + Value: strconv.Itoa(int(constants.MLServerISRestPort)), + }, + { + Name: constants.MLServerGRPCPortEnv, + Value: strconv.Itoa(int(constants.MLServerISGRPCPort)), + }, + { + Name: constants.MLServerModelsDirEnv, + Value: constants.DefaultModelLocalMountPath, + }, + }), + }, + } + + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + res := scenario.spec.SKLearn.getEnvVarsV2() + if !g.Expect(res).To(scenario.matcher) { + t.Errorf("got %q, want %q", res, scenario.matcher) + } + }) + } +} diff --git a/pkg/apis/serving/v1beta1/predictor_test.go b/pkg/apis/serving/v1beta1/predictor_test.go 
new file mode 100644 index 00000000000..fc8b941980d --- /dev/null +++ b/pkg/apis/serving/v1beta1/predictor_test.go @@ -0,0 +1,69 @@ +/* +Copyright 2022 The KServe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package v1beta1 + +import ( + "testing" + + "github.com/golang/protobuf/proto" + "github.com/onsi/gomega" + v1 "k8s.io/api/core/v1" +) + +func makeTestPredictorSpec() *PredictorSpec { + return &PredictorSpec{ + PyTorch: &TorchServeSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + RuntimeVersion: proto.String("0.4.1"), + }, + }, + } +} + +func TestGetPredictorImplementations(t *testing.T) { + g := gomega.NewGomegaWithT(t) + spec := makeTestPredictorSpec() + implementations := spec.GetPredictorImplementations() + g.Expect(len(implementations)).ShouldNot(gomega.BeZero()) + g.Expect(implementations[0]).Should(gomega.Equal(spec.PyTorch)) + + spec.PyTorch = nil + implementations = spec.GetPredictorImplementations() + g.Expect(len(implementations)).Should(gomega.BeZero()) + + spec.PodSpec.Containers = []v1.Container{ + { + Name: "Test-Container", + Image: "test/predictor", + }, + } + implementations = spec.GetPredictorImplementations() + g.Expect(len(implementations)).ShouldNot(gomega.BeZero()) + g.Expect(implementations[0]).Should(gomega.Equal(NewCustomPredictor(&spec.PodSpec))) +} + +func TestGetPredictorImplementation(t *testing.T) { + g := gomega.NewGomegaWithT(t) + spec := makeTestPredictorSpec() + expected := spec.PyTorch + implementation := spec.GetPredictorImplementation() + g.Expect(*implementation).Should(gomega.Equal(expected)) + + spec.PyTorch = nil + implementation = spec.GetPredictorImplementation() + g.Expect(implementation).Should(gomega.BeNil()) +} diff --git a/pkg/apis/serving/v1beta1/predictor_tfserving_test.go b/pkg/apis/serving/v1beta1/predictor_tfserving_test.go index 0c72dd3f6e3..d4eeb8c3c8c 100644 --- a/pkg/apis/serving/v1beta1/predictor_tfserving_test.go +++ b/pkg/apis/serving/v1beta1/predictor_tfserving_test.go @@ -18,6 +18,7 @@ package v1beta1 import ( "fmt" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "testing" "github.com/golang/protobuf/proto" @@ -148,3 +149,75 @@ func TestTensorflowDefaulter(t *testing.T) { }) } } + +func TestTFServingSpec_GetContainer(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + metadata := metav1.ObjectMeta{Name: constants.InferenceServiceContainerName} + scenarios := map[string]struct { + spec PredictorSpec + }{ + "simple": { + spec: PredictorSpec{ + Tensorflow: &TFServingSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("s3://modelzoo"), + Container: v1.Container{ + Name: constants.InferenceServiceContainerName, + Image: "image:0.1", + Args: nil, + Env: nil, + Resources: v1.ResourceRequirements{}, + }, + }, + }, + ComponentExtensionSpec: ComponentExtensionSpec{}, + }, + }, + } + + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + res := scenario.spec.Tensorflow.GetContainer(metadata, &scenario.spec.ComponentExtensionSpec, nil) + if 
!g.Expect(res).To(gomega.Equal(&scenario.spec.Tensorflow.Container)) { + t.Errorf("got %v, want %v", res, scenario.spec.Tensorflow.Container) + } + }) + } +} + +func TestTFServingSpec_GetProtocol(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + scenarios := map[string]struct { + spec PredictorSpec + expected constants.InferenceServiceProtocol + }{ + "default": { + spec: PredictorSpec{ + Tensorflow: &TFServingSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("s3://modelzoo"), + Container: v1.Container{ + Image: "image:0.1", + Args: nil, + Env: nil, + Resources: v1.ResourceRequirements{}, + }, + }, + }, + ComponentExtensionSpec: ComponentExtensionSpec{}, + }, + expected: constants.ProtocolV1, + }, + } + + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + res := scenario.spec.Tensorflow.GetProtocol() + if !g.Expect(res).To(gomega.Equal(scenario.expected)) { + t.Errorf("got %v, want %v", res, scenario.expected) + } + }) + } +} diff --git a/pkg/apis/serving/v1beta1/predictor_torchserve_test.go b/pkg/apis/serving/v1beta1/predictor_torchserve_test.go index c4b01dd4cfb..a0f5e0b2ee7 100644 --- a/pkg/apis/serving/v1beta1/predictor_torchserve_test.go +++ b/pkg/apis/serving/v1beta1/predictor_torchserve_test.go @@ -18,6 +18,7 @@ package v1beta1 import ( "fmt" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "testing" "github.com/golang/protobuf/proto" @@ -173,3 +174,67 @@ func TestTorchServeDefaulter(t *testing.T) { }) } } + +func TestTorchServeSpec_GetProtocol(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + scenarios := map[string]struct { + spec PredictorSpec + expected constants.InferenceServiceProtocol + }{ + "default": { + spec: PredictorSpec{ + PyTorch: &TorchServeSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{}, + }, + ComponentExtensionSpec: ComponentExtensionSpec{}, + }, + expected: constants.ProtocolV1, + }, + } + + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + res := scenario.spec.PyTorch.GetProtocol() + if !g.Expect(res).To(gomega.Equal(scenario.expected)) { + t.Errorf("got %v, want %v", res, scenario.expected) + } + }) + } +} + +func TestTorchServeSpec_GetContainer(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + metadata := metav1.ObjectMeta{Name: constants.InferenceServiceContainerName} + scenarios := map[string]struct { + spec PredictorSpec + }{ + "simple": { + spec: PredictorSpec{ + PyTorch: &TorchServeSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("s3://modelzoo"), + Container: v1.Container{ + Name: constants.InferenceServiceContainerName, + Image: "image:0.1", + Args: nil, + Env: nil, + Resources: v1.ResourceRequirements{}, + }, + }, + }, + ComponentExtensionSpec: ComponentExtensionSpec{}, + }, + }, + } + + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + res := scenario.spec.PyTorch.GetContainer(metadata, &scenario.spec.ComponentExtensionSpec, nil) + if !g.Expect(res).To(gomega.Equal(&scenario.spec.PyTorch.Container)) { + t.Errorf("got %v, want %v", res, scenario.spec.PyTorch.Container) + } + }) + } +} diff --git a/pkg/apis/serving/v1beta1/predictor_triton_test.go b/pkg/apis/serving/v1beta1/predictor_triton_test.go index a89ce8ccf97..f8fce8c1981 100644 --- a/pkg/apis/serving/v1beta1/predictor_triton_test.go +++ b/pkg/apis/serving/v1beta1/predictor_triton_test.go @@ -17,6 +17,7 @@ limitations under the License.
package v1beta1 import ( + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "testing" "github.com/golang/protobuf/proto" @@ -122,3 +123,131 @@ func TestTritonDefaulter(t *testing.T) { }) } } + +func TestTritonSpec_GetContainer(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + metadata := metav1.ObjectMeta{Name: constants.InferenceServiceContainerName} + scenarios := map[string]struct { + spec PredictorSpec + }{ + "simple": { + spec: PredictorSpec{ + Triton: &TritonSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("s3://modelzoo"), + Container: v1.Container{ + Name: constants.InferenceServiceContainerName, + Image: "image:0.1", + Args: nil, + Env: nil, + Resources: v1.ResourceRequirements{}, + }, + }, + }, + ComponentExtensionSpec: ComponentExtensionSpec{}, + }, + }, + } + + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + res := scenario.spec.Triton.GetContainer(metadata, &scenario.spec.ComponentExtensionSpec, nil) + if !g.Expect(res).To(gomega.Equal(&scenario.spec.Triton.Container)) { + t.Errorf("got %v, want %v", res, scenario.spec.Triton.Container) + } + }) + } +} + +func TestTritonSpec_Default(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + scenarios := map[string]struct { + spec PredictorSpec + expected *TritonSpec + }{ + "simple": { + spec: PredictorSpec{ + Triton: &TritonSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("s3://modelzoo"), + Container: v1.Container{ + Image: "image:0.1", + Args: nil, + Env: nil, + Resources: v1.ResourceRequirements{}, + }, + }, + }, + ComponentExtensionSpec: ComponentExtensionSpec{}, + }, + expected: &TritonSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("s3://modelzoo"), + Container: v1.Container{ + Name: constants.InferenceServiceContainerName, + Image: "image:0.1", + Args: nil, + Env: nil, + Resources: v1.ResourceRequirements{ + Limits: v1.ResourceList{ + "cpu": resource.MustParse("1"), + "memory": resource.MustParse("2Gi"), + }, + Requests: v1.ResourceList{ + "memory": resource.MustParse("2Gi"), + "cpu": resource.MustParse("1"), + }, + }, + }, + }, + }, + }, + } + + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + scenario.spec.Triton.Default(nil) + if !g.Expect(scenario.spec.Triton).To(gomega.Equal(scenario.expected)) { + t.Errorf("got %v, want %v", scenario.spec.Triton, scenario.expected) + } + }) + } +} + +func TestTritonSpec_GetProtocol(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + scenarios := map[string]struct { + spec PredictorSpec + expected constants.InferenceServiceProtocol + }{ + "default": { + spec: PredictorSpec{ + Triton: &TritonSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("s3://modelzoo"), + Container: v1.Container{ + Image: "image:0.1", + Args: nil, + Env: nil, + Resources: v1.ResourceRequirements{}, + }, + }, + }, + ComponentExtensionSpec: ComponentExtensionSpec{}, + }, + expected: constants.ProtocolV2, + }, + } + + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + res := scenario.spec.Triton.GetProtocol() + if !g.Expect(res).To(gomega.Equal(scenario.expected)) { + t.Errorf("got %v, want %v", scenario.spec.Triton, scenario.expected) + } + }) + } +} diff --git a/pkg/apis/serving/v1beta1/predictor_xgboost_test.go b/pkg/apis/serving/v1beta1/predictor_xgboost_test.go index 545cae5159a..380f51484e4 100644 --- a/pkg/apis/serving/v1beta1/predictor_xgboost_test.go +++ 
b/pkg/apis/serving/v1beta1/predictor_xgboost_test.go @@ -17,9 +17,12 @@ limitations under the License. package v1beta1 import ( + "strconv" "testing" "github.com/golang/protobuf/proto" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "github.com/kserve/kserve/pkg/constants" "github.com/onsi/gomega" "github.com/onsi/gomega/types" @@ -91,6 +94,28 @@ func TestXGBoostDefaulter(t *testing.T) { spec PredictorSpec expected PredictorSpec }{ + "DefaultRuntimeVersion": { + spec: PredictorSpec{ + XGBoost: &XGBoostSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{}, + }, + }, + expected: PredictorSpec{ + XGBoost: &XGBoostSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + //RuntimeVersion: proto.String("v0.4.0"), + ProtocolVersion: &protocolV1, + Container: v1.Container{ + Name: constants.InferenceServiceContainerName, + Resources: v1.ResourceRequirements{ + Requests: defaultResource, + Limits: defaultResource, + }, + }, + }, + }, + }, + }, "DefaultRuntimeVersionAndProtocol": { spec: PredictorSpec{ XGBoost: &XGBoostSpec{ @@ -102,6 +127,7 @@ func TestXGBoostDefaulter(t *testing.T) { expected: PredictorSpec{ XGBoost: &XGBoostSpec{ PredictorExtensionSpec: PredictorExtensionSpec{ + //RuntimeVersion: proto.String("v0.1.2"), ProtocolVersion: &protocolV2, Container: v1.Container{ Name: constants.InferenceServiceContainerName, @@ -149,3 +175,405 @@ func TestXGBoostDefaulter(t *testing.T) { }) } } + +func TestCreateXGBoostModelServingContainerV1(t *testing.T) { + + var requestedResource = v1.ResourceRequirements{ + Limits: v1.ResourceList{ + "cpu": resource.MustParse("100m"), + }, + Requests: v1.ResourceList{ + "cpu": resource.MustParse("90m"), + }, + } + var config = InferenceServicesConfig{} + g := gomega.NewGomegaWithT(t) + scenarios := map[string]struct { + isvc InferenceService + expectedContainerSpec *v1.Container + }{ + "ContainerSpecWithDefaultImage": { + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "xgboost", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + XGBoost: &XGBoostSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://someUri"), + RuntimeVersion: proto.String("0.1.0"), + Container: v1.Container{ + Resources: requestedResource, + }, + }, + }, + }, + }, + }, + expectedContainerSpec: &v1.Container{ + Name: constants.InferenceServiceContainerName, + Resources: requestedResource, + }, + }, + "ContainerSpecWithCustomImage": { + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "xgboost", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + XGBoost: &XGBoostSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://someUri"), + Container: v1.Container{ + Image: "customImage:0.1.0", + Resources: requestedResource, + }, + }, + }, + }, + }, + }, + expectedContainerSpec: &v1.Container{ + Image: "customImage:0.1.0", + Name: constants.InferenceServiceContainerName, + Resources: requestedResource, + }, + }, + "ContainerSpecWithContainerConcurrency": { + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "xgboost", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + ComponentExtensionSpec: ComponentExtensionSpec{ + ContainerConcurrency: proto.Int64(1), + }, + XGBoost: &XGBoostSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://someUri"), + RuntimeVersion: proto.String("0.1.0"), + Container: v1.Container{ + Resources: requestedResource, + }, + }, + }, + }, + }, + }, + expectedContainerSpec: &v1.Container{ + 
Name: constants.InferenceServiceContainerName, + Resources: requestedResource, + }, + }, + "ContainerSpecWithWorker": { + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "xgboost", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + ComponentExtensionSpec: ComponentExtensionSpec{ + ContainerConcurrency: proto.Int64(2), + }, + XGBoost: &XGBoostSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://someUri"), + RuntimeVersion: proto.String("0.1.0"), + Container: v1.Container{ + Resources: requestedResource, + Args: []string{ + constants.ArgumentWorkers + "=1", + }, + }, + }, + }, + }, + }, + }, + expectedContainerSpec: &v1.Container{ + Name: constants.InferenceServiceContainerName, + Resources: requestedResource, + Args: []string{ + "--workers=1", + }, + }, + }, + } + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + predictor := scenario.isvc.Spec.Predictor.GetImplementation() + predictor.Default(&config) + res := predictor.GetContainer(metav1.ObjectMeta{Name: "someName"}, &scenario.isvc.Spec.Predictor.ComponentExtensionSpec, &config) + if !g.Expect(res).To(gomega.Equal(scenario.expectedContainerSpec)) { + t.Errorf("got %q, want %q", res, scenario.expectedContainerSpec) + } + }) + } +} + +func TestCreateXGBoostModelServingContainerV2(t *testing.T) { + protocolV2 := constants.ProtocolV2 + + var requestedResource = v1.ResourceRequirements{ + Limits: v1.ResourceList{ + "cpu": resource.Quantity{ + Format: "100", + }, + }, + Requests: v1.ResourceList{ + "cpu": resource.Quantity{ + Format: "90", + }, + }, + } + var config = InferenceServicesConfig{} + g := gomega.NewGomegaWithT(t) + scenarios := map[string]struct { + isvc InferenceService + expectedContainerSpec *v1.Container + }{ + "ContainerSpecWithDefaultImage": { + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "xgboost", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + XGBoost: &XGBoostSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://someUri"), + RuntimeVersion: proto.String("0.1.0"), + ProtocolVersion: &protocolV2, + Container: v1.Container{ + Resources: requestedResource, + }, + }, + }, + }, + }, + }, + expectedContainerSpec: &v1.Container{ + Name: constants.InferenceServiceContainerName, + Resources: requestedResource, + }, + }, + "ContainerSpecWithCustomImage": { + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "xgboost", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + XGBoost: &XGBoostSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://someUri"), + ProtocolVersion: &protocolV2, + Container: v1.Container{ + Image: "customImage:0.1.0", + Resources: requestedResource, + }, + }, + }, + }, + }, + }, + expectedContainerSpec: &v1.Container{ + Image: "customImage:0.1.0", + Name: constants.InferenceServiceContainerName, + Resources: requestedResource, + }, + }, + "ContainerSpecWithoutStorageURI": { + isvc: InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "xgboost", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + XGBoost: &XGBoostSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + ProtocolVersion: &protocolV2, + Container: v1.Container{ + Resources: requestedResource, + }, + }, + }, + }, + }, + }, + expectedContainerSpec: &v1.Container{ + Name: constants.InferenceServiceContainerName, + Resources: requestedResource, + }, + }, + } + for name, scenario := range scenarios { + 
t.Run(name, func(t *testing.T) { + predictor := scenario.isvc.Spec.Predictor.GetImplementation() + predictor.Default(&config) + res := predictor.GetContainer(scenario.isvc.ObjectMeta, &scenario.isvc.Spec.Predictor.ComponentExtensionSpec, &config) + if !g.Expect(res).To(gomega.Equal(scenario.expectedContainerSpec)) { + t.Errorf("got %q, want %q", res, scenario.expectedContainerSpec) + } + }) + } +} + +func TestXGBoostGetProtocol(t *testing.T) { + g := gomega.NewGomegaWithT(t) + scenarios := map[string]struct { + spec PredictorSpec + matcher types.GomegaMatcher + }{ + "DefaultProtocol": { + spec: PredictorSpec{ + XGBoost: &XGBoostSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{}, + }, + }, + matcher: gomega.Equal(constants.ProtocolV1), + }, + "ProtocolV2": { + spec: PredictorSpec{ + XGBoost: &XGBoostSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + ProtocolVersion: (*constants.InferenceServiceProtocol)(proto.String(string(constants.ProtocolV2))), + }, + }, + }, + matcher: gomega.Equal(constants.ProtocolV2), + }, + } + for _, scenario := range scenarios { + protocol := scenario.spec.XGBoost.GetProtocol() + g.Expect(protocol).To(scenario.matcher) + } +} + +func TestXGBoostSpec_GetEnvVarsV2(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + scenarios := map[string]struct { + spec PredictorSpec + matcher types.GomegaMatcher + }{ + "storage uri is nil": { + spec: PredictorSpec{ + XGBoost: &XGBoostSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{}, + }, + }, + matcher: gomega.Equal([]v1.EnvVar{ + { + Name: constants.MLServerHTTPPortEnv, + Value: strconv.Itoa(int(constants.MLServerISRestPort)), + }, + { + Name: constants.MLServerGRPCPortEnv, + Value: strconv.Itoa(int(constants.MLServerISGRPCPort)), + }, + { + Name: constants.MLServerModelsDirEnv, + Value: constants.DefaultModelLocalMountPath, + }, + { + Name: constants.MLServerLoadModelsStartupEnv, + Value: strconv.FormatBool(false), + }, + }), + }, + "storage uri is not nil": { + spec: PredictorSpec{ + XGBoost: &XGBoostSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://kserve/model"), + }, + }, + }, + matcher: gomega.Equal([]v1.EnvVar{ + { + Name: constants.MLServerHTTPPortEnv, + Value: strconv.Itoa(int(constants.MLServerISRestPort)), + }, + { + Name: constants.MLServerGRPCPortEnv, + Value: strconv.Itoa(int(constants.MLServerISGRPCPort)), + }, + { + Name: constants.MLServerModelsDirEnv, + Value: constants.DefaultModelLocalMountPath, + }, + }), + }, + } + + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + res := scenario.spec.XGBoost.getEnvVarsV2() + if !g.Expect(res).To(scenario.matcher) { + t.Errorf("got %q, want %q", res, scenario.matcher) + } + }) + } +} + +func TestXGBoostSpec_GetDefaultsV2(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + metadata := metav1.ObjectMeta{ + Name: "test", + } + scenarios := map[string]struct { + spec PredictorSpec + matcher types.GomegaMatcher + }{ + "storage uri is nil": { + spec: PredictorSpec{ + XGBoost: &XGBoostSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{}, + }, + }, + matcher: gomega.Equal([]v1.EnvVar{ + { + Name: constants.MLServerModelImplementationEnv, + Value: constants.MLServerXGBoostImplementation, + }, + }), + }, + "storage uri is not nil": { + spec: PredictorSpec{ + XGBoost: &XGBoostSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + StorageURI: proto.String("gs://kserve/model"), + }, + }, + }, + matcher: gomega.Equal([]v1.EnvVar{ + { + Name: constants.MLServerModelImplementationEnv, + Value: 
constants.MLServerXGBoostImplementation, + }, + { + Name: constants.MLServerModelNameEnv, + Value: metadata.Name, + }, + { + Name: constants.MLServerModelURIEnv, + Value: constants.DefaultModelLocalMountPath, + }, + }), + }, + } + + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + res := scenario.spec.XGBoost.getDefaultsV2(metadata) + if !g.Expect(res).To(scenario.matcher) { + t.Errorf("got %q, want %q", res, scenario.matcher) + } + }) + } +} diff --git a/pkg/batcher/handler_test.go b/pkg/batcher/handler_test.go index 4f3715988ec..0efb9107c2b 100644 --- a/pkg/batcher/handler_test.go +++ b/pkg/batcher/handler_test.go @@ -86,3 +86,87 @@ func TestBatcher(t *testing.T) { <-responseChan wg.Wait() } + +// Tests batcher when inference response code is other than 200 +func TestBatcherFail(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + logger, _ := pkglogging.NewLogger("", "INFO") + + responseChan := make(chan Response) + // Start a local HTTP server + predictor := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + b, err := ioutil.ReadAll(req.Body) + g.Expect(err).To(gomega.BeNil()) + var request Request + err = json.Unmarshal(b, &request) + g.Expect(err).To(gomega.BeNil()) + logger.Infof("Get request %v", string(b)) + response := Response{} + responseChan <- response + responseBytes, err := json.Marshal(response) + g.Expect(err).To(gomega.BeNil()) + rw.WriteHeader(500) + _, err = rw.Write(responseBytes) + g.Expect(err).To(gomega.BeNil()) + })) + // Close the server when test finishes + defer predictor.Close() + predictorSvcUrl, err := url.Parse(predictor.URL) + logger.Infof("predictor url %s", predictorSvcUrl) + g.Expect(err).To(gomega.BeNil()) + httpProxy := httputil.NewSingleHostReverseProxy(predictorSvcUrl) + batchHandler := New(32, 50, httpProxy, logger) + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go serveRequest(batchHandler, &wg, i) + } + //var responseBytes []byte + <-responseChan + wg.Wait() +} + +// Tests default max batch size and max latency +func TestBatcherDefaults(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + logger, _ := pkglogging.NewLogger("", "INFO") + + responseChan := make(chan Response) + // Start a local HTTP server + predictor := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + b, err := ioutil.ReadAll(req.Body) + g.Expect(err).To(gomega.BeNil()) + var request Request + err = json.Unmarshal(b, &request) + g.Expect(err).To(gomega.BeNil()) + logger.Infof("Get request %v", string(b)) + response := Response{ + Predictions: request.Instances, + } + responseChan <- response + responseBytes, err := json.Marshal(response) + g.Expect(err).To(gomega.BeNil()) + rw.WriteHeader(500) + _, err = rw.Write(responseBytes) + g.Expect(err).To(gomega.BeNil()) + })) + // Close the server when test finishes + defer predictor.Close() + predictorSvcUrl, err := url.Parse(predictor.URL) + logger.Infof("predictor url %s", predictorSvcUrl) + g.Expect(err).To(gomega.BeNil()) + httpProxy := httputil.NewSingleHostReverseProxy(predictorSvcUrl) + batchHandler := New(-1, -1, httpProxy, logger) + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go serveRequest(batchHandler, &wg, i) + } + //var responseBytes []byte + <-responseChan + wg.Wait() + g.Expect(batchHandler.MaxBatchSize).To(gomega.Equal(MaxBatchSize)) + g.Expect(batchHandler.MaxLatency).To(gomega.Equal(MaxLatency)) +} diff --git 
a/pkg/controller/v1beta1/inferenceservice/reconcilers/ingress/ingress_reconciler_test.go b/pkg/controller/v1beta1/inferenceservice/reconcilers/ingress/ingress_reconciler_test.go index 1d2e3bb93ab..0f3df3036aa 100644 --- a/pkg/controller/v1beta1/inferenceservice/reconcilers/ingress/ingress_reconciler_test.go +++ b/pkg/controller/v1beta1/inferenceservice/reconcilers/ingress/ingress_reconciler_test.go @@ -17,11 +17,11 @@ limitations under the License. package ingress import ( - "testing" - "github.com/google/go-cmp/cmp" "github.com/kserve/kserve/pkg/apis/serving/v1beta1" "github.com/kserve/kserve/pkg/constants" + "github.com/onsi/gomega" + gomegaTypes "github.com/onsi/gomega/types" istiov1alpha3 "istio.io/api/networking/v1alpha3" "istio.io/client-go/pkg/apis/networking/v1alpha3" corev1 "k8s.io/api/core/v1" @@ -29,6 +29,8 @@ import ( "knative.dev/pkg/apis" duckv1 "knative.dev/pkg/apis/duck/v1" "knative.dev/pkg/network" + "net/url" + "testing" ) func TestCreateVirtualService(t *testing.T) { @@ -543,3 +545,140 @@ func createInferenceServiceWithHostname(hostName string) *v1beta1.InferenceServi }, } } + +func TestGetServiceUrl(t *testing.T) { + g := gomega.NewGomegaWithT(t) + serviceName := "my-model" + namespace := "test" + isvcAnnotations := map[string]string{"test": "test", "kubectl.kubernetes.io/last-applied-configuration": "test"} + labels := map[string]string{"test": "test"} + predictorUrl, _ := url.Parse("http://my-model-predictor-default.example.com") + transformerUrl, _ := url.Parse("http://my-model-transformer-default.example.com") + urlScheme := "http" + + cases := map[string]struct { + isvc *v1beta1.InferenceService + matcher gomegaTypes.GomegaMatcher + }{ + "component is empty": { + isvc: &v1beta1.InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: serviceName, + Namespace: namespace, + Annotations: isvcAnnotations, + Labels: labels, + }, + Spec: v1beta1.InferenceServiceSpec{ + Predictor: v1beta1.PredictorSpec{}, + }, + }, + matcher: gomega.Equal(""), + }, + "predictor url is empty": { + isvc: &v1beta1.InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: serviceName, + Namespace: namespace, + Annotations: isvcAnnotations, + Labels: labels, + }, + Spec: v1beta1.InferenceServiceSpec{ + Predictor: v1beta1.PredictorSpec{ + SKLearn: &v1beta1.SKLearnSpec{}, + }, + }, + Status: v1beta1.InferenceServiceStatus{ + Status: duckv1.Status{}, + Address: nil, + URL: nil, + Components: map[v1beta1.ComponentType]v1beta1.ComponentStatusSpec{ + v1beta1.PredictorComponent: v1beta1.ComponentStatusSpec{}, + }, + ModelStatus: v1beta1.ModelStatus{}, + }, + }, + matcher: gomega.Equal(""), + }, + "predictor url is not empty": { + isvc: &v1beta1.InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: serviceName, + Namespace: namespace, + Annotations: isvcAnnotations, + Labels: labels, + }, + Spec: v1beta1.InferenceServiceSpec{ + Predictor: v1beta1.PredictorSpec{ + SKLearn: &v1beta1.SKLearnSpec{}, + }, + }, + Status: v1beta1.InferenceServiceStatus{ + Status: duckv1.Status{}, + Address: nil, + URL: nil, + Components: map[v1beta1.ComponentType]v1beta1.ComponentStatusSpec{ + v1beta1.PredictorComponent: v1beta1.ComponentStatusSpec{ + URL: (*apis.URL)(predictorUrl), + }, + }, + }, + }, + matcher: gomega.Equal("http://my-model.example.com"), + }, + "transformer is not empty": { + isvc: &v1beta1.InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: serviceName, + Namespace: namespace, + Annotations: isvcAnnotations, + Labels: labels, + }, + Spec: v1beta1.InferenceServiceSpec{ + Predictor: 
v1beta1.PredictorSpec{ + SKLearn: &v1beta1.SKLearnSpec{}, + }, + Transformer: &v1beta1.TransformerSpec{}, + }, + Status: v1beta1.InferenceServiceStatus{ + Status: duckv1.Status{}, + Address: nil, + URL: nil, + Components: map[v1beta1.ComponentType]v1beta1.ComponentStatusSpec{ + v1beta1.PredictorComponent: v1beta1.ComponentStatusSpec{ + URL: (*apis.URL)(predictorUrl), + }, + v1beta1.TransformerComponent: v1beta1.ComponentStatusSpec{ + URL: (*apis.URL)(transformerUrl), + }, + }, + }, + }, + matcher: gomega.Equal("http://my-model.example.com"), + }, + "predictor is empty": { + isvc: &v1beta1.InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: serviceName, + Namespace: namespace, + Annotations: isvcAnnotations, + Labels: labels, + }, + Spec: v1beta1.InferenceServiceSpec{}, + Status: v1beta1.InferenceServiceStatus{ + Status: duckv1.Status{}, + Address: nil, + URL: nil, + Components: map[v1beta1.ComponentType]v1beta1.ComponentStatusSpec{}, + }, + }, + matcher: gomega.Equal(""), + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + url := getServiceUrl(tc.isvc, urlScheme, false) + g.Expect(url).Should(tc.matcher) + }) + } +} diff --git a/pkg/controller/v1beta1/inferenceservice/utils/utils_test.go b/pkg/controller/v1beta1/inferenceservice/utils/utils_test.go index 698b87c304c..203a53ed00e 100644 --- a/pkg/controller/v1beta1/inferenceservice/utils/utils_test.go +++ b/pkg/controller/v1beta1/inferenceservice/utils/utils_test.go @@ -1178,3 +1178,50 @@ func TestUpdateImageTag(t *testing.T) { }) } } + +func TestGetDeploymentMode(t *testing.T) { + g := gomega.NewGomegaWithT(t) + scenarios := map[string]struct { + annotations map[string]string + deployConfig *v1beta1.DeployConfig + expected constants.DeploymentModeType + }{ + "RawDeployment": { + annotations: map[string]string{ + constants.DeploymentMode: string(constants.RawDeployment), + }, + deployConfig: &v1beta1.DeployConfig{}, + expected: constants.DeploymentModeType(constants.RawDeployment), + }, + "ServerlessDeployment": { + annotations: map[string]string{ + constants.DeploymentMode: string(constants.Serverless), + }, + deployConfig: &v1beta1.DeployConfig{}, + expected: constants.DeploymentModeType(constants.Serverless), + }, + "ModelMeshDeployment": { + annotations: map[string]string{ + constants.DeploymentMode: string(constants.ModelMeshDeployment), + }, + deployConfig: &v1beta1.DeployConfig{}, + expected: constants.DeploymentModeType(constants.ModelMeshDeployment), + }, + "DefaultDeploymentMode": { + annotations: map[string]string{}, + deployConfig: &v1beta1.DeployConfig{ + DefaultDeploymentMode: string(constants.Serverless), + }, + expected: constants.DeploymentModeType(constants.Serverless), + }, + } + + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + deploymentMode := GetDeploymentMode(scenario.annotations, scenario.deployConfig) + if !g.Expect(deploymentMode).To(gomega.Equal(scenario.expected)) { + t.Errorf("got %v, want %v", deploymentMode, scenario.expected) + } + }) + } +} diff --git a/pkg/credentials/service_account_credentials_test.go b/pkg/credentials/service_account_credentials_test.go index 5e31859aa66..f36d4c1298c 100644 --- a/pkg/credentials/service_account_credentials_test.go +++ b/pkg/credentials/service_account_credentials_test.go @@ -18,6 +18,7 @@ package credentials import ( "context" + "github.com/onsi/gomega/types" "testing" "github.com/kserve/kserve/pkg/credentials/azure" @@ -753,3 +754,514 @@ func TestAzureStorageAccessKeyCredentialBuilder(t *testing.T) { 
g.Expect(c.Delete(context.TODO(), customAzureSecret)).NotTo(gomega.HaveOccurred()) g.Expect(c.Delete(context.TODO(), customOnlyServiceAccount)).NotTo(gomega.HaveOccurred()) } + +func TestCredentialBuilder_CreateStorageSpecSecretEnvs(t *testing.T) { + g := gomega.NewGomegaWithT(t) + namespace := "default" + builder := NewCredentialBulder(c, configMap) + + scenarios := map[string]struct { + secret *v1.Secret + storageKey string + storageSecretName string + overrideParams map[string]string + container *v1.Container + shouldFail bool + matcher types.GomegaMatcher + }{ + "fail on storage secret name is empty": { + secret: &v1.Secret{ + TypeMeta: metav1.TypeMeta{ + Kind: "Secret", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "test-secret", + Namespace: namespace, + }, + Data: nil, + }, + storageKey: "", + storageSecretName: "", + overrideParams: make(map[string]string), + container: &v1.Container{}, + shouldFail: true, + matcher: gomega.HaveOccurred(), + }, + "storage spec with empty override params": { + secret: &v1.Secret{ + TypeMeta: metav1.TypeMeta{ + Kind: "Secret", + APIVersion: "v1", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "storage-secret", + Namespace: namespace, + }, + StringData: map[string]string{"minio": "{\n \"type\": \"s3\",\n \"access_key_id\": \"minio\",\n \"secret_access_key\": \"minio123\",\n \"endpoint_url\": \"http://minio-service.kubeflow:9000\",\n \"bucket\": \"test-bucket\",\n \"region\": \"us-south\"\n }"}, + }, + storageKey: "minio", + storageSecretName: "storage-secret", + overrideParams: map[string]string{"type": "", "bucket": ""}, + container: &v1.Container{ + Name: "init-container", + Image: "kserve/init-container:latest", + Args: []string{ + "s3://test-bucket/models/", + "/mnt/models/", + }, + }, + shouldFail: false, + matcher: gomega.Equal(&v1.Container{ + Name: "init-container", + Image: "kserve/init-container:latest", + Args: []string{ + "s3://test-bucket/models/", + "/mnt/models/", + }, + Env: []v1.EnvVar{ + { + Name: "STORAGE_CONFIG", + Value: "", + ValueFrom: &v1.EnvVarSource{ + FieldRef: nil, + ResourceFieldRef: nil, + ConfigMapKeyRef: nil, + SecretKeyRef: &v1.SecretKeySelector{ + LocalObjectReference: v1.LocalObjectReference{ + Name: "storage-secret", + }, + Key: "minio", + Optional: nil, + }, + }, + }, + { + Name: "STORAGE_OVERRIDE_CONFIG", + Value: "{\"bucket\":\"\",\"type\":\"\"}", + }, + }, + }), + }, + "simple storage spec": { + secret: &v1.Secret{ + TypeMeta: metav1.TypeMeta{ + Kind: "Secret", + APIVersion: "v1", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "storage-secret", + Namespace: namespace, + }, + StringData: map[string]string{"minio": "{\n \"type\": \"s3\",\n \"access_key_id\": \"minio\",\n \"secret_access_key\": \"minio123\",\n \"endpoint_url\": \"http://minio-service.kubeflow:9000\",\n \"bucket\": \"test-bucket\",\n \"region\": \"us-south\"\n }"}, + }, + storageKey: "minio", + storageSecretName: "storage-secret", + overrideParams: map[string]string{"type": "s3", "bucket": "test-bucket"}, + container: &v1.Container{ + Name: "init-container", + Image: "kserve/init-container:latest", + Args: []string{ + "s3://test-bucket/models/", + "/mnt/models/", + }, + }, + shouldFail: false, + matcher: gomega.Equal(&v1.Container{ + Name: "init-container", + Image: "kserve/init-container:latest", + Args: []string{ + "s3://test-bucket/models/", + "/mnt/models/", + }, + Env: []v1.EnvVar{ + { + Name: "STORAGE_CONFIG", + Value: "", + ValueFrom: &v1.EnvVarSource{ + FieldRef: nil, + ResourceFieldRef: nil, + ConfigMapKeyRef: nil, + SecretKeyRef: 
&v1.SecretKeySelector{ + LocalObjectReference: v1.LocalObjectReference{ + Name: "storage-secret", + }, + Key: "minio", + Optional: nil, + }, + }, + }, + { + Name: "STORAGE_OVERRIDE_CONFIG", + Value: "{\"bucket\":\"test-bucket\",\"type\":\"s3\"}", + }, + }, + }), + }, + "wrong storage key": { + secret: &v1.Secret{ + TypeMeta: metav1.TypeMeta{ + Kind: "Secret", + APIVersion: "v1", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "storage-secret", + Namespace: namespace, + }, + StringData: map[string]string{"minio": "{\n \"type\": \"s3\",\n \"access_key_id\": \"minio\",\n \"secret_access_key\": \"minio123\",\n \"endpoint_url\": \"http://minio-service.kubeflow:9000\",\n \"bucket\": \"test-bucket\",\n \"region\": \"us-south\"\n }"}, + }, + storageKey: "wrong-key", + storageSecretName: "storage-secret", + overrideParams: map[string]string{"type": "s3", "bucket": "test-bucket"}, + container: &v1.Container{ + Name: "init-container", + Image: "kserve/init-container:latest", + Args: []string{ + "s3://test-bucket/models/", + "/mnt/models/", + }, + }, + shouldFail: true, + matcher: gomega.HaveOccurred(), + }, + "default storage key": { + secret: &v1.Secret{ + TypeMeta: metav1.TypeMeta{ + Kind: "Secret", + APIVersion: "v1", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "storage-secret", + Namespace: namespace, + }, + StringData: map[string]string{DefaultStorageSecretKey + "_s3": "{\n \"type\": \"s3\",\n \"access_key_id\": \"minio\",\n \"secret_access_key\": \"minio123\",\n \"endpoint_url\": \"http://minio-service.kubeflow:9000\",\n \"bucket\": \"test-bucket\",\n \"region\": \"us-south\"\n }"}, + }, + storageKey: "", + storageSecretName: "storage-secret", + overrideParams: map[string]string{"type": "s3", "bucket": "test-bucket"}, + container: &v1.Container{ + Name: "init-container", + Image: "kserve/init-container:latest", + Args: []string{ + "s3://test-bucket/models/", + "/mnt/models/", + }, + }, + shouldFail: false, + matcher: gomega.Equal(&v1.Container{ + Name: "init-container", + Image: "kserve/init-container:latest", + Args: []string{ + "s3://test-bucket/models/", + "/mnt/models/", + }, + Env: []v1.EnvVar{ + { + Name: "STORAGE_CONFIG", + Value: "", + ValueFrom: &v1.EnvVarSource{ + FieldRef: nil, + ResourceFieldRef: nil, + ConfigMapKeyRef: nil, + SecretKeyRef: &v1.SecretKeySelector{ + LocalObjectReference: v1.LocalObjectReference{ + Name: "storage-secret", + }, + Key: "default_s3", + Optional: nil, + }, + }, + }, + { + Name: "STORAGE_OVERRIDE_CONFIG", + Value: "{\"bucket\":\"test-bucket\",\"type\":\"s3\"}", + }, + }, + }), + }, + "default storage key with empty storage type": { + secret: &v1.Secret{ + TypeMeta: metav1.TypeMeta{ + Kind: "Secret", + APIVersion: "v1", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "storage-secret", + Namespace: namespace, + }, + StringData: map[string]string{DefaultStorageSecretKey: "{\n \"type\": \"s3\",\n \"access_key_id\": \"minio\",\n \"secret_access_key\": \"minio123\",\n \"endpoint_url\": \"http://minio-service.kubeflow:9000\",\n \"bucket\": \"test-bucket\",\n \"region\": \"us-south\"\n }"}, + }, + storageKey: "", + storageSecretName: "storage-secret", + overrideParams: map[string]string{"type": "", "bucket": "test-bucket"}, + container: &v1.Container{ + Name: "init-container", + Image: "kserve/init-container:latest", + Args: []string{ + "s3://test-bucket/models/", + "/mnt/models/", + }, + }, + shouldFail: false, + matcher: gomega.Equal(&v1.Container{ + Name: "init-container", + Image: "kserve/init-container:latest", + Args: []string{ + "s3://test-bucket/models/", + 
"/mnt/models/", + }, + Env: []v1.EnvVar{ + { + Name: "STORAGE_CONFIG", + Value: "", + ValueFrom: &v1.EnvVarSource{ + FieldRef: nil, + ResourceFieldRef: nil, + ConfigMapKeyRef: nil, + SecretKeyRef: &v1.SecretKeySelector{ + LocalObjectReference: v1.LocalObjectReference{ + Name: "storage-secret", + }, + Key: "default", + Optional: nil, + }, + }, + }, + { + Name: "STORAGE_OVERRIDE_CONFIG", + Value: "{\"bucket\":\"test-bucket\",\"type\":\"\"}", + }, + }, + }), + }, + "storage spec with uri scheme placeholder": { + secret: &v1.Secret{ + TypeMeta: metav1.TypeMeta{ + Kind: "Secret", + APIVersion: "v1", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "storage-secret", + Namespace: namespace, + }, + StringData: map[string]string{"minio": "{\n \"type\": \"s3\",\n \"access_key_id\": \"minio\",\n \"secret_access_key\": \"minio123\",\n \"endpoint_url\": \"http://minio-service.kubeflow:9000\",\n \"bucket\": \"test-bucket\",\n \"region\": \"us-south\"\n }"}, + }, + storageKey: "minio", + storageSecretName: "storage-secret", + overrideParams: map[string]string{"type": "s3", "bucket": "test-bucket"}, + container: &v1.Container{ + Name: "init-container", + Image: "kserve/init-container:latest", + Args: []string{ + "://models/example-model/", + "/mnt/models/", + }, + }, + shouldFail: false, + matcher: gomega.Equal(&v1.Container{ + Name: "init-container", + Image: "kserve/init-container:latest", + Args: []string{ + "s3://test-bucket/models/example-model/", + "/mnt/models/", + }, + Env: []v1.EnvVar{ + { + Name: "STORAGE_CONFIG", + Value: "", + ValueFrom: &v1.EnvVarSource{ + FieldRef: nil, + ResourceFieldRef: nil, + ConfigMapKeyRef: nil, + SecretKeyRef: &v1.SecretKeySelector{ + LocalObjectReference: v1.LocalObjectReference{ + Name: "storage-secret", + }, + Key: "minio", + Optional: nil, + }, + }, + }, + { + Name: "STORAGE_OVERRIDE_CONFIG", + Value: "{\"bucket\":\"test-bucket\",\"type\":\"s3\"}", + }, + }, + }), + }, + "hdfs with uri scheme placeholder": { + secret: &v1.Secret{ + TypeMeta: metav1.TypeMeta{ + Kind: "Secret", + APIVersion: "v1", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "storage-secret", + Namespace: namespace, + }, + StringData: map[string]string{"hdfs": "{\n \"type\": \"hdfs\",\n \"access_key_id\": \"hdfs34\",\n \"secret_access_key\": \"hdfs123\",\n \"endpoint_url\": \"http://hdfs-service.kubeflow\",\n \"region\": \"us-south\"\n }"}, + }, + storageKey: "hdfs", + storageSecretName: "storage-secret", + overrideParams: map[string]string{"type": "hdfs", "bucket": ""}, + container: &v1.Container{ + Name: "init-container", + Image: "kserve/init-container:latest", + Args: []string{ + "://models/example-model/", + "/mnt/models/", + }, + }, + shouldFail: false, + matcher: gomega.Equal(&v1.Container{ + Name: "init-container", + Image: "kserve/init-container:latest", + Args: []string{ + "hdfs://models/example-model/", + "/mnt/models/", + }, + Env: []v1.EnvVar{ + { + Name: "STORAGE_CONFIG", + Value: "", + ValueFrom: &v1.EnvVarSource{ + FieldRef: nil, + ResourceFieldRef: nil, + ConfigMapKeyRef: nil, + SecretKeyRef: &v1.SecretKeySelector{ + LocalObjectReference: v1.LocalObjectReference{ + Name: "storage-secret", + }, + Key: "hdfs", + Optional: nil, + }, + }, + }, + { + Name: "STORAGE_OVERRIDE_CONFIG", + Value: "{\"bucket\":\"\",\"type\":\"hdfs\"}", + }, + }, + }), + }, + "unsupported storage type": { + secret: &v1.Secret{ + TypeMeta: metav1.TypeMeta{ + Kind: "Secret", + APIVersion: "v1", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "storage-secret", + Namespace: namespace, + }, + StringData: 
map[string]string{"minio": "{\n \"type\": \"gs\",\n \"access_key_id\": \"minio\",\n \"secret_access_key\": \"minio123\",\n \"endpoint_url\": \"http://minio-service.kubeflow:9000\",\n \"bucket\": \"test-bucket\",\n \"region\": \"us-south\"\n }"}, + }, + storageKey: "minio", + storageSecretName: "storage-secret", + overrideParams: map[string]string{"type": "", "bucket": "test-bucket"}, + container: &v1.Container{ + Name: "init-container", + Image: "kserve/init-container:latest", + Args: []string{ + "gs://test-bucket/models/", + "/mnt/models/", + }, + }, + shouldFail: true, + matcher: gomega.HaveOccurred(), + }, + "secret data with syntax error": { + secret: &v1.Secret{ + TypeMeta: metav1.TypeMeta{ + Kind: "Secret", + APIVersion: "v1", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "storage-secret", + Namespace: namespace, + }, + StringData: map[string]string{"minio": "{\n { \"type\": \"s3\",\n \"access_key_id\": \"minio\",\n \"secret_access_key\": \"minio123\",\n \"endpoint_url\": \"http://minio-service.kubeflow:9000\",\n \"bucket\": \"test-bucket\",\n \"region\": \"us-south\"\n }"}, + }, + storageKey: "minio", + storageSecretName: "storage-secret", + overrideParams: map[string]string{"type": "", "bucket": "test-bucket"}, + container: &v1.Container{ + Name: "init-container", + Image: "kserve/init-container:latest", + Args: []string{ + "s3://test-bucket/models/", + "/mnt/models/", + }, + }, + shouldFail: true, + matcher: gomega.HaveOccurred(), + }, + "fail on storage type is empty": { + secret: &v1.Secret{ + TypeMeta: metav1.TypeMeta{ + Kind: "Secret", + APIVersion: "v1", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "storage-secret", + Namespace: namespace, + }, + StringData: map[string]string{"minio": "{\n \"type\": \"\",\n \"access_key_id\": \"minio\",\n \"secret_access_key\": \"minio123\",\n \"endpoint_url\": \"http://minio-service.kubeflow:9000\",\n \"bucket\": \"test-bucket\",\n \"region\": \"us-south\"\n }"}, + }, + storageKey: "minio", + storageSecretName: "storage-secret", + overrideParams: map[string]string{"type": "", "bucket": "test-bucket"}, + container: &v1.Container{ + Name: "init-container", + Image: "kserve/init-container:latest", + Args: []string{ + "s3://test-bucket/models/", + "/mnt/models/", + }, + }, + shouldFail: true, + matcher: gomega.HaveOccurred(), + }, + "fail on bucket is empty on s3 storage": { + secret: &v1.Secret{ + TypeMeta: metav1.TypeMeta{ + Kind: "Secret", + APIVersion: "v1", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "storage-secret", + Namespace: namespace, + }, + StringData: map[string]string{"minio": "{\n \"type\": \"s3\",\n \"access_key_id\": \"minio\",\n \"secret_access_key\": \"minio123\",\n \"endpoint_url\": \"http://minio-service.kubeflow:9000\",\n \"bucket\": \"\",\n \"region\": \"us-south\"\n }"}, + }, + storageKey: "minio", + storageSecretName: "storage-secret", + overrideParams: map[string]string{"type": "s3", "bucket": ""}, + container: &v1.Container{ + Name: "init-container", + Image: "kserve/init-container:latest", + Args: []string{ + "://models/example-model/", + "/mnt/models/", + }, + }, + shouldFail: true, + matcher: gomega.HaveOccurred(), + }, + } + + for _, tc := range scenarios { + if err := c.Create(context.TODO(), tc.secret); err != nil { + t.Errorf("Failed to create secret %s: %v", "storage-secret", err) + } + err := builder.CreateStorageSpecSecretEnvs(namespace, tc.storageKey, tc.storageSecretName, tc.overrideParams, tc.container) + if !tc.shouldFail { + g.Expect(err).Should(gomega.BeNil()) + 
g.Expect(tc.container).Should(tc.matcher) + } else { + g.Expect(err).To(tc.matcher) + } + if err := c.Delete(context.TODO(), tc.secret); err != nil { + t.Errorf("Failed to delete secret %s because of: %v", tc.secret.Name, err) + } + } +} diff --git a/pkg/modelconfig/configmap_test.go b/pkg/modelconfig/configmap_test.go index 30b5e732cfa..148ff88abe8 100644 --- a/pkg/modelconfig/configmap_test.go +++ b/pkg/modelconfig/configmap_test.go @@ -21,8 +21,11 @@ import ( "testing" "github.com/kserve/kserve/pkg/apis/serving/v1alpha1" + "github.com/kserve/kserve/pkg/apis/serving/v1beta1" "github.com/kserve/kserve/pkg/constants" + "github.com/kserve/kserve/pkg/controller/v1alpha1/trainedmodel/sharding/memory" testify "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/proto" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/controller-runtime/pkg/log" @@ -240,3 +243,41 @@ func getSortedConfigData(input string) (output ModelConfigs, err error) { }) return output, nil } + +func TestCreateEmptyModelConfig(t *testing.T) { + isvc := &v1beta1.InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "foo", + Namespace: "default", + Annotations: map[string]string{ + constants.DeploymentMode: string(constants.ModelMeshDeployment), + }, + }, + Spec: v1beta1.InferenceServiceSpec{ + Predictor: v1beta1.PredictorSpec{ + Tensorflow: &v1beta1.TFServingSpec{ + PredictorExtensionSpec: v1beta1.PredictorExtensionSpec{ + StorageURI: proto.String("gs://testbucket/testmodel"), + }, + }, + }, + }, + } + shardStrategy := memory.MemoryStrategy{} + shardId := shardStrategy.GetShard(isvc)[0] + expected := &v1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: constants.ModelConfigName(isvc.Name, shardId), + Namespace: isvc.Namespace, + Labels: isvc.Labels, + }, + Data: map[string]string{ + constants.ModelConfigFileName: "[]", + }, + } + + configMap, err := CreateEmptyModelConfig(isvc, shardId) + testify.Nil(t, err) + testify.Equal(t, configMap, expected) + +} diff --git a/pkg/utils/utils_test.go b/pkg/utils/utils_test.go index 99401a6f19a..390a4507f63 100644 --- a/pkg/utils/utils_test.go +++ b/pkg/utils/utils_test.go @@ -17,10 +17,15 @@ limitations under the License. 
package utils import ( + "errors" "testing" + "github.com/kserve/kserve/pkg/constants" "github.com/kserve/kserve/pkg/credentials/gcs" + "github.com/onsi/gomega" + "github.com/onsi/gomega/types" v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" "github.com/google/go-cmp/cmp" ) @@ -387,3 +392,152 @@ func TestMergeEnvs(t *testing.T) { } } } + +func TestIncludesArg(t *testing.T) { + g := gomega.NewGomegaWithT(t) + args := []string{ + constants.ArgumentModelName, + } + scenarios := map[string]struct { + arg string + expected bool + }{ + "SliceContainsArg": { + arg: constants.ArgumentModelName, + expected: true, + }, + "SliceNotContainsArg": { + arg: "NoArg", + expected: false, + }, + } + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + res := IncludesArg(args, scenario.arg) + g.Expect(res).To(gomega.Equal(scenario.expected)) + }) + } +} + +func TestIsGpuEnabled(t *testing.T) { + g := gomega.NewGomegaWithT(t) + scenarios := map[string]struct { + resource v1.ResourceRequirements + expected bool + }{ + "GpuEnabled": { + resource: v1.ResourceRequirements{ + Limits: v1.ResourceList{ + "cpu": resource.Quantity{ + Format: "100", + }, + constants.NvidiaGPUResourceType: resource.MustParse("1"), + }, + Requests: v1.ResourceList{ + "cpu": resource.Quantity{ + Format: "90", + }, + constants.NvidiaGPUResourceType: resource.MustParse("1"), + }, + }, + expected: true, + }, + "GPUDisabled": { + resource: v1.ResourceRequirements{ + Limits: v1.ResourceList{ + "cpu": resource.Quantity{ + Format: "100", + }, + }, + Requests: v1.ResourceList{ + "cpu": resource.Quantity{ + Format: "90", + }, + }, + }, + expected: false, + }, + } + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + res := IsGPUEnabled(scenario.resource) + g.Expect(res).To(gomega.Equal(scenario.expected)) + }) + } +} + +func TestFirstNonNilError(t *testing.T) { + g := gomega.NewGomegaWithT(t) + scenarios := map[string]struct { + errors []error + matcher types.GomegaMatcher + }{ + "NoNonNilError": { + errors: []error{ + nil, + nil, + }, + matcher: gomega.BeNil(), + }, + "ContainsError": { + errors: []error{ + nil, + errors.New("First non nil error"), + errors.New("Second non nil error"), + }, + matcher: gomega.Equal(errors.New("First non nil error")), + }, + } + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + err := FirstNonNilError(scenario.errors) + g.Expect(err).Should(scenario.matcher) + }) + } +} + +func TestRemoveString(t *testing.T) { + g := gomega.NewGomegaWithT(t) + testStrings := []string{ + "Model Tensorflow", + "SKLearn Model", + "Model", + "ModelPytorch", + } + expected := []string{ + "Model Tensorflow", + "SKLearn Model", + "ModelPytorch", + } + res := RemoveString(testStrings, "Model") + g.Expect(res).Should(gomega.Equal(expected)) +} + +func TestIsPrefixSupported(t *testing.T) { + g := gomega.NewGomegaWithT(t) + prefixes := []string{ + "S3://", + "GCS://", + "HTTP://", + "HTTPS://", + } + scenarios := map[string]struct { + input string + expected bool + }{ + "SupportedPrefix": { + input: "GCS://test/model", + expected: true, + }, + "UnSupportedPreifx": { + input: "PVC://test/model", + expected: false, + }, + } + for name, scenario := range scenarios { + t.Run(name, func(t *testing.T) { + res := IsPrefixSupported(scenario.input, prefixes) + g.Expect(res).Should(gomega.Equal(scenario.expected)) + }) + } +} diff --git a/pkg/webhook/admission/pod/agent_injector_test.go b/pkg/webhook/admission/pod/agent_injector_test.go index 
df091203753..35f19f93847 100644 --- a/pkg/webhook/admission/pod/agent_injector_test.go +++ b/pkg/webhook/admission/pod/agent_injector_test.go @@ -21,6 +21,8 @@ import ( "github.com/kserve/kserve/pkg/apis/serving/v1beta1" "github.com/kserve/kserve/pkg/credentials" + "github.com/onsi/gomega" + "github.com/onsi/gomega/types" "k8s.io/apimachinery/pkg/api/resource" "k8s.io/apimachinery/pkg/util/intstr" @@ -368,6 +370,142 @@ func TestAgentInjector(t *testing.T) { }, }, }, + "AddBatcher": { + original: &v1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "deployment", + Namespace: "default", + Annotations: map[string]string{ + constants.BatcherInternalAnnotationKey: "true", + constants.BatcherMaxLatencyInternalAnnotationKey: "100", + constants.BatcherMaxBatchSizeInternalAnnotationKey: "30", + }, + Labels: map[string]string{ + "serving.kserve.io/inferenceservice": "sklearn", + constants.KServiceModelLabel: "sklearn", + constants.KServiceEndpointLabel: "default", + constants.KServiceComponentLabel: "predictor", + }, + }, + Spec: v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "sklearn", + ReadinessProbe: &v1.Probe{ + ProbeHandler: v1.ProbeHandler{ + TCPSocket: &v1.TCPSocketAction{ + Port: intstr.IntOrString{ + IntVal: 8080, + }, + }, + }, + InitialDelaySeconds: 0, + TimeoutSeconds: 1, + PeriodSeconds: 10, + SuccessThreshold: 1, + FailureThreshold: 3, + }, + }, + { + Name: "queue-proxy", + Env: []v1.EnvVar{{Name: "SERVING_READINESS_PROBE", Value: "{\"tcpSocket\":{\"port\":8080},\"timeoutSeconds\":1,\"periodSeconds\":10,\"successThreshold\":1,\"failureThreshold\":3}"}}, + }, + }, + }, + }, + expected: &v1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "deployment", + Annotations: map[string]string{ + constants.BatcherInternalAnnotationKey: "true", + constants.BatcherMaxLatencyInternalAnnotationKey: "100", + constants.BatcherMaxBatchSizeInternalAnnotationKey: "30", + }, + }, + Spec: v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "sklearn", + ReadinessProbe: &v1.Probe{ + ProbeHandler: v1.ProbeHandler{ + TCPSocket: &v1.TCPSocketAction{ + Port: intstr.IntOrString{ + IntVal: 8080, + }, + }, + }, + InitialDelaySeconds: 0, + TimeoutSeconds: 1, + PeriodSeconds: 10, + SuccessThreshold: 1, + FailureThreshold: 3, + }, + }, + { + Name: "queue-proxy", + Env: []v1.EnvVar{{Name: "SERVING_READINESS_PROBE", Value: "{\"tcpSocket\":{\"port\":8080},\"timeoutSeconds\":1,\"periodSeconds\":10,\"successThreshold\":1,\"failureThreshold\":3}"}}, + }, + { + Name: constants.AgentContainerName, + Image: loggerConfig.Image, + Args: []string{ + BatcherEnableFlag, + BatcherArgumentMaxBatchSize, + "30", + BatcherArgumentMaxLatency, + "100", + }, + Ports: []v1.ContainerPort{ + { + Name: "agent-port", + ContainerPort: constants.InferenceServiceDefaultAgentPort, + Protocol: "TCP", + }, + }, + Env: []v1.EnvVar{{Name: "SERVING_READINESS_PROBE", Value: "{\"tcpSocket\":{\"port\":8080},\"timeoutSeconds\":1,\"periodSeconds\":10,\"successThreshold\":1,\"failureThreshold\":3}"}}, + Resources: agentResourceRequirement, + ReadinessProbe: &v1.Probe{ + ProbeHandler: v1.ProbeHandler{ + HTTPGet: &v1.HTTPGetAction{ + HTTPHeaders: []v1.HTTPHeader{ + { + Name: "K-Network-Probe", + Value: "queue", + }, + }, + Port: intstr.FromInt(9081), + Path: "/", + Scheme: "HTTP", + }, + }, + }, + }, + }, + }, + }, + }, + "DoNotAddBatcher": { + original: &v1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "deployment", + }, + Spec: v1.PodSpec{ + Containers: []v1.Container{{ + Name: "sklearn", + }}, + }, + }, + expected: &v1.Pod{ + ObjectMeta: 
metav1.ObjectMeta{ + Name: "deployment", + }, + Spec: v1.PodSpec{ + Containers: []v1.Container{{ + Name: "sklearn", + }}, + }, + }, + }, } credentialBuilder := credentials.NewCredentialBulder(c, &v1.ConfigMap{ @@ -387,3 +525,143 @@ func TestAgentInjector(t *testing.T) { } } } + +func TestGetLoggerConfigs(t *testing.T) { + g := gomega.NewGomegaWithT(t) + cases := []struct { + name string + configMap *v1.ConfigMap + matchers []types.GomegaMatcher + }{ + { + name: "Valid Logger Config", + configMap: &v1.ConfigMap{ + TypeMeta: metav1.TypeMeta{}, + ObjectMeta: metav1.ObjectMeta{}, + Data: map[string]string{ + LoggerConfigMapKeyName: `{ + "Image": "gcr.io/kfserving/logger:latest", + "CpuRequest": "100m", + "CpuLimit": "1", + "MemoryRequest": "200Mi", + "MemoryLimit": "1Gi" + }`, + }, + BinaryData: map[string][]byte{}, + }, + matchers: []types.GomegaMatcher{ + gomega.Equal(&LoggerConfig{ + Image: "gcr.io/kfserving/logger:latest", + CpuRequest: "100m", + CpuLimit: "1", + MemoryRequest: "200Mi", + MemoryLimit: "1Gi", + }), + gomega.BeNil(), + }, + }, + { + name: "Invalid Resource Value", + configMap: &v1.ConfigMap{ + TypeMeta: metav1.TypeMeta{}, + ObjectMeta: metav1.ObjectMeta{}, + Data: map[string]string{ + LoggerConfigMapKeyName: `{ + "Image": "gcr.io/kfserving/logger:latest", + "CpuRequest": "100m", + "CpuLimit": "1", + "MemoryRequest": "200mc", + "MemoryLimit": "1Gi" + }`, + }, + BinaryData: map[string][]byte{}, + }, + matchers: []types.GomegaMatcher{ + gomega.Equal(&LoggerConfig{ + Image: "gcr.io/kfserving/logger:latest", + CpuRequest: "100m", + CpuLimit: "1", + MemoryRequest: "200mc", + MemoryLimit: "1Gi", + }), + gomega.HaveOccurred(), + }, + }, + } + + for _, tc := range cases { + loggerConfigs, err := getLoggerConfigs(tc.configMap) + g.Expect(err).Should(tc.matchers[1]) + g.Expect(loggerConfigs).Should(tc.matchers[0]) + } +} + +func TestGetAgentConfigs(t *testing.T) { + g := gomega.NewGomegaWithT(t) + cases := []struct { + name string + configMap *v1.ConfigMap + matchers []types.GomegaMatcher + }{ + { + name: "Valid Agent Config", + configMap: &v1.ConfigMap{ + TypeMeta: metav1.TypeMeta{}, + ObjectMeta: metav1.ObjectMeta{}, + Data: map[string]string{ + constants.AgentConfigMapKeyName: `{ + "Image": "gcr.io/kfserving/agent:latest", + "CpuRequest": "100m", + "CpuLimit": "1", + "MemoryRequest": "200Mi", + "MemoryLimit": "1Gi" + }`, + }, + BinaryData: map[string][]byte{}, + }, + matchers: []types.GomegaMatcher{ + gomega.Equal(&AgentConfig{ + Image: "gcr.io/kfserving/agent:latest", + CpuRequest: "100m", + CpuLimit: "1", + MemoryRequest: "200Mi", + MemoryLimit: "1Gi", + }), + gomega.BeNil(), + }, + }, + { + name: "Invalid Resource Value", + configMap: &v1.ConfigMap{ + TypeMeta: metav1.TypeMeta{}, + ObjectMeta: metav1.ObjectMeta{}, + Data: map[string]string{ + constants.AgentConfigMapKeyName: `{ + "Image": "gcr.io/kfserving/agent:latest", + "CpuRequest": "100m", + "CpuLimit": "1", + "MemoryRequest": "200mc", + "MemoryLimit": "1Gi" + }`, + }, + BinaryData: map[string][]byte{}, + }, + matchers: []types.GomegaMatcher{ + gomega.Equal(&AgentConfig{ + Image: "gcr.io/kfserving/agent:latest", + CpuRequest: "100m", + CpuLimit: "1", + MemoryRequest: "200mc", + MemoryLimit: "1Gi", + }), + gomega.HaveOccurred(), + }, + }, + } + + for _, tc := range cases { + loggerConfigs, err := getAgentConfigs(tc.configMap) + g.Expect(err).Should(tc.matchers[1]) + g.Expect(loggerConfigs).Should(tc.matchers[0]) + } +} diff --git a/pkg/webhook/admission/pod/batcher_injector_test.go 
b/pkg/webhook/admission/pod/batcher_injector_test.go index 0b825bf6e54..8b0db047c8d 100644 --- a/pkg/webhook/admission/pod/batcher_injector_test.go +++ b/pkg/webhook/admission/pod/batcher_injector_test.go @@ -23,6 +23,8 @@ import ( "knative.dev/pkg/kmp" "github.com/kserve/kserve/pkg/constants" + "github.com/onsi/gomega" + "github.com/onsi/gomega/types" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) @@ -150,3 +152,73 @@ func TestBatcherInjector(t *testing.T) { } } } + +func TestGetBatcherConfigs(t *testing.T) { + g := gomega.NewGomegaWithT(t) + cases := []struct { + name string + configMap *v1.ConfigMap + matchers []types.GomegaMatcher + }{ + { + name: "Valid Batcher Config", + configMap: &v1.ConfigMap{ + TypeMeta: metav1.TypeMeta{}, + ObjectMeta: metav1.ObjectMeta{}, + Data: map[string]string{ + BatcherConfigMapKeyName: `{ + "Image": "gcr.io/kfserving/batcher:latest", + "CpuRequest": "100m", + "CpuLimit": "1", + "MemoryRequest": "200Mi", + "MemoryLimit": "1Gi" + }`, + }, + BinaryData: map[string][]byte{}, + }, + matchers: []types.GomegaMatcher{ + gomega.Equal(&BatcherConfig{ + Image: "gcr.io/kfserving/batcher:latest", + CpuRequest: "100m", + CpuLimit: "1", + MemoryRequest: "200Mi", + MemoryLimit: "1Gi", + }), + gomega.BeNil(), + }, + }, + { + name: "Invalid Resource Value", + configMap: &v1.ConfigMap{ + TypeMeta: metav1.TypeMeta{}, + ObjectMeta: metav1.ObjectMeta{}, + Data: map[string]string{ + BatcherConfigMapKeyName: `{ + "Image": "gcr.io/kfserving/batcher:latest", + "CpuRequest": "100m", + "CpuLimit": "1", + "MemoryRequest": "200mc", + "MemoryLimit": "1Gi" + }`, + }, + BinaryData: map[string][]byte{}, + }, + matchers: []types.GomegaMatcher{ + gomega.Equal(&BatcherConfig{ + Image: "gcr.io/kfserving/batcher:latest", + CpuRequest: "100m", + CpuLimit: "1", + MemoryRequest: "200mc", + MemoryLimit: "1Gi", + }), + gomega.HaveOccurred(), + }, + }, + } + + for _, tc := range cases { + loggerConfigs, err := getBatcherConfigs(tc.configMap) + g.Expect(err).Should(tc.matchers[1]) + g.Expect(loggerConfigs).Should(tc.matchers[0]) + } +} diff --git a/pkg/webhook/admission/pod/mutator_test.go b/pkg/webhook/admission/pod/mutator_test.go new file mode 100644 index 00000000000..fe3b58a65a1 --- /dev/null +++ b/pkg/webhook/admission/pod/mutator_test.go @@ -0,0 +1,297 @@ +package pod + +import ( + "context" + "encoding/json" + "github.com/golang/protobuf/proto" + "github.com/google/uuid" + "github.com/kserve/kserve/pkg/constants" + "github.com/onsi/gomega" + gomegaTypes "github.com/onsi/gomega/types" + "gomodules.xyz/jsonpatch/v2" + admissionv1 "k8s.io/api/admission/v1" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + "testing" +) + +func TestMutator_Handle(t *testing.T) { + g := gomega.NewGomegaWithT(t) + kserveNamespace := v1.Namespace{ + TypeMeta: metav1.TypeMeta{ + Kind: "Namespace", + APIVersion: "v1", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: constants.KServeNamespace, + }, + Spec: v1.NamespaceSpec{}, + Status: v1.NamespaceStatus{}, + } + + if err := c.Create(context.TODO(), &kserveNamespace); err != nil { + t.Errorf("failed to create namespace: %v", err) + } + + mutator := Mutator{} + if err := mutator.InjectClient(c); err != nil { + t.Errorf("failed to inject client: %v", err) + } + + decoder, _ := admission.NewDecoder(c.Scheme()) + if err := mutator.InjectDecoder(decoder); err != nil { + t.Errorf("failed to inject 
decoder: %v", err) + } + + cases := map[string]struct { + configMap v1.ConfigMap + request admission.Request + pod v1.Pod + matcher gomegaTypes.GomegaMatcher + }{ + "should not mutate non isvc pods": { + configMap: v1.ConfigMap{ + TypeMeta: metav1.TypeMeta{ + Kind: "ConfigMap", + APIVersion: "v1", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: constants.InferenceServiceConfigMapName, + Namespace: constants.KServeNamespace, + }, + Immutable: nil, + Data: map[string]string{ + StorageInitializerConfigMapKeyName: `{ + "image" : "kserve/storage-initializer:latest", + "memoryRequest": "100Mi", + "memoryLimit": "1Gi", + "cpuRequest": "100m", + "cpuLimit": "1", + "storageSpecSecretName": "storage-config" + }`, + LoggerConfigMapKeyName: `{ + "image" : "kserve/agent:latest", + "memoryRequest": "100Mi", + "memoryLimit": "1Gi", + "cpuRequest": "100m", + "cpuLimit": "1", + "defaultUrl": "http://default-broker" + }`, + BatcherConfigMapKeyName: `{ + "image" : "kserve/agent:latest", + "memoryRequest": "1Gi", + "memoryLimit": "1Gi", + "cpuRequest": "1", + "cpuLimit": "1" + }`, + constants.AgentConfigMapKeyName: `{ + "image" : "kserve/agent:latest", + "memoryRequest": "100Mi", + "memoryLimit": "1Gi", + "cpuRequest": "100m", + "cpuLimit": "1" + }`, + }, + BinaryData: nil, + }, + request: admission.Request{ + AdmissionRequest: admissionv1.AdmissionRequest{ + UID: types.UID(uuid.NewString()), + Kind: metav1.GroupVersionKind{ + Group: "", + Version: "v1", + Kind: "Pod", + }, + Resource: metav1.GroupVersionResource{ + Group: "", + Version: "v1", + Resource: "pods", + }, + SubResource: "", + RequestKind: &metav1.GroupVersionKind{ + Group: "", + Version: "v1", + Kind: "Pod", + }, + RequestResource: &metav1.GroupVersionResource{ + Group: "", + Version: "v1", + Resource: "pods", + }, + RequestSubResource: "", + Name: "", + Namespace: "default", + Operation: admissionv1.Create, + Object: runtime.RawExtension{}, + OldObject: runtime.RawExtension{}, + DryRun: nil, + Options: runtime.RawExtension{}, + }, + }, + pod: v1.Pod{ + TypeMeta: metav1.TypeMeta{ + Kind: "Pod", + APIVersion: "v1", + }, + }, + matcher: gomega.Equal(admission.Response{ + Patches: nil, + AdmissionResponse: admissionv1.AdmissionResponse{ + UID: "", + Allowed: true, + Result: &metav1.Status{ + TypeMeta: metav1.TypeMeta{}, + ListMeta: metav1.ListMeta{}, + Status: "", + Message: "", + Reason: "", + Details: nil, + Code: 200, + }, + Patch: nil, + PatchType: nil, + AuditAnnotations: nil, + Warnings: nil, + }, + }), + }, + "should mutate isvc pods": { + configMap: v1.ConfigMap{ + TypeMeta: metav1.TypeMeta{ + Kind: "ConfigMap", + APIVersion: "v1", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: constants.InferenceServiceConfigMapName, + Namespace: constants.KServeNamespace, + }, + Immutable: nil, + Data: map[string]string{ + StorageInitializerConfigMapKeyName: `{ + "image" : "kserve/storage-initializer:latest", + "memoryRequest": "100Mi", + "memoryLimit": "1Gi", + "cpuRequest": "100m", + "cpuLimit": "1", + "storageSpecSecretName": "storage-config" + }`, + LoggerConfigMapKeyName: `{ + "image" : "kserve/agent:latest", + "memoryRequest": "100Mi", + "memoryLimit": "1Gi", + "cpuRequest": "100m", + "cpuLimit": "1", + "defaultUrl": "http://default-broker" + }`, + BatcherConfigMapKeyName: `{ + "image" : "kserve/agent:latest", + "memoryRequest": "1Gi", + "memoryLimit": "1Gi", + "cpuRequest": "1", + "cpuLimit": "1" + }`, + constants.AgentConfigMapKeyName: `{ + "image" : "kserve/agent:latest", + "memoryRequest": "100Mi", + "memoryLimit": "1Gi", + "cpuRequest": "100m", + 
"cpuLimit": "1" + }`, + }, + BinaryData: nil, + }, + request: admission.Request{ + AdmissionRequest: admissionv1.AdmissionRequest{ + UID: types.UID(uuid.NewString()), + Kind: metav1.GroupVersionKind{ + Group: "", + Version: "v1", + Kind: "Pod", + }, + Resource: metav1.GroupVersionResource{ + Group: "", + Version: "v1", + Resource: "pods", + }, + SubResource: "", + RequestKind: &metav1.GroupVersionKind{ + Group: "", + Version: "v1", + Kind: "Pod", + }, + RequestResource: &metav1.GroupVersionResource{ + Group: "", + Version: "v1", + Resource: "pods", + }, + RequestSubResource: "", + Name: "", + Namespace: "default", + Operation: admissionv1.Create, + Object: runtime.RawExtension{}, + OldObject: runtime.RawExtension{}, + DryRun: nil, + Options: runtime.RawExtension{}, + }, + }, + pod: v1.Pod{ + TypeMeta: metav1.TypeMeta{ + Kind: "Pod", + APIVersion: "v1", + }, + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + constants.InferenceServicePodLabelKey: "", + }, + }, + Spec: v1.PodSpec{ + Containers: []v1.Container{ + { + Name: constants.InferenceServiceContainerName, + }, + }, + }, + }, + matcher: gomega.Equal(admission.Response{ + Patches: []jsonpatch.JsonPatchOperation{ + { + Operation: "add", + Path: "/metadata/namespace", + Value: "default", + }, + }, + AdmissionResponse: admissionv1.AdmissionResponse{ + UID: "", + Allowed: true, + Result: nil, + Patch: nil, + PatchType: (*admissionv1.PatchType)(proto.String(string(admissionv1.PatchTypeJSONPatch))), + AuditAnnotations: nil, + Warnings: nil, + }, + }), + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + if err := c.Create(context.TODO(), &tc.configMap); err != nil { + t.Errorf("failed to create config map: %v", err) + } + byteData, err := json.Marshal(tc.pod) + if err != nil { + t.Errorf("failed to marshal pod data: %v", err) + } + tc.request.Object.Raw = byteData + res := mutator.Handle(context.TODO(), tc.request) + g.Expect(res).Should(tc.matcher) + if err := c.Delete(context.TODO(), &tc.configMap); err != nil { + t.Errorf("failed to delete configmap %v", err) + } + }) + } + +} diff --git a/pkg/webhook/admission/pod/storage_initializer_injector_test.go b/pkg/webhook/admission/pod/storage_initializer_injector_test.go index 63b3d584e68..06158575bce 100644 --- a/pkg/webhook/admission/pod/storage_initializer_injector_test.go +++ b/pkg/webhook/admission/pod/storage_initializer_injector_test.go @@ -28,6 +28,7 @@ import ( "github.com/kserve/kserve/pkg/credentials/gcs" "github.com/kserve/kserve/pkg/credentials/s3" "github.com/onsi/gomega" + "github.com/onsi/gomega/types" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) @@ -1042,3 +1043,115 @@ func TestStorageInitializerConfigmap(t *testing.T) { } } } + +func TestGetStorageInitializerConfigs(t *testing.T) { + g := gomega.NewGomegaWithT(t) + cases := []struct { + name string + configMap *v1.ConfigMap + matchers []types.GomegaMatcher + }{ + { + name: "Valid Storage Initializer Config", + configMap: &v1.ConfigMap{ + TypeMeta: metav1.TypeMeta{}, + ObjectMeta: metav1.ObjectMeta{}, + Data: map[string]string{ + StorageInitializerConfigMapKeyName: `{ + "Image": "gcr.io/kfserving/storage-initializer:latest", + "CpuRequest": "100m", + "CpuLimit": "1", + "MemoryRequest": "200Mi", + "MemoryLimit": "1Gi", + "StorageSpecSecretName": "storage-secret" + }`, + }, + BinaryData: map[string][]byte{}, + }, + matchers: []types.GomegaMatcher{ + gomega.Equal(&StorageInitializerConfig{ + Image: "gcr.io/kfserving/storage-initializer:latest", + CpuRequest: "100m", 
+ CpuLimit: "1", + MemoryRequest: "200Mi", + MemoryLimit: "1Gi", + StorageSpecSecretName: "storage-secret", + }), + gomega.BeNil(), + }, + }, + { + name: "Invalid Resource Value", + configMap: &v1.ConfigMap{ + TypeMeta: metav1.TypeMeta{}, + ObjectMeta: metav1.ObjectMeta{}, + Data: map[string]string{ + StorageInitializerConfigMapKeyName: `{ + "Image": "gcr.io/kfserving/storage-initializer:latest", + "CpuRequest": "100m", + "CpuLimit": "1", + "MemoryRequest": "200MC", + "MemoryLimit": "1Gi", + "StorageSpecSecretName": "storage-secret" + }`, + }, + BinaryData: map[string][]byte{}, + }, + matchers: []types.GomegaMatcher{ + gomega.Equal(&StorageInitializerConfig{ + Image: "gcr.io/kfserving/storage-initializer:latest", + CpuRequest: "100m", + CpuLimit: "1", + MemoryRequest: "200MC", + MemoryLimit: "1Gi", + StorageSpecSecretName: "storage-secret", + }), + gomega.HaveOccurred(), + }, + }, + } + + for _, tc := range cases { + loggerConfigs, err := getStorageInitializerConfigs(tc.configMap) + g.Expect(err).Should(tc.matchers[1]) + g.Expect(loggerConfigs).Should(tc.matchers[0]) + } +} + +func TestParsePvcURI(t *testing.T) { + g := gomega.NewGomegaWithT(t) + cases := []struct { + name string + uri string + matchers []types.GomegaMatcher + }{ + { + name: "Valid PVC URI", + uri: "pvc://test/model/model1", + matchers: []types.GomegaMatcher{ + gomega.Equal("test"), + gomega.Equal("model/model1"), + gomega.BeNil(), + }, + }, + { + name: "Valid PVC URI with Shortest Path", + uri: "pvc://test", + matchers: []types.GomegaMatcher{ + gomega.Equal("test"), + gomega.Equal(""), + gomega.BeNil(), + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + pvcName, pvcPath, err := parsePvcURI(tc.uri) + g.Expect(pvcName).Should(tc.matchers[0]) + g.Expect(pvcPath).Should(tc.matchers[1]) + g.Expect(err).Should(tc.matchers[2]) + }) + + } +} diff --git a/python/kserve/test/test_creds_utils.py b/python/kserve/test/test_creds_utils.py index b3858d36e41..0e7c53d2408 100644 --- a/python/kserve/test/test_creds_utils.py +++ b/python/kserve/test/test_creds_utils.py @@ -12,11 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import json +import tempfile from unittest import mock -from kubernetes.client import V1ServiceAccountList, V1ServiceAccount, V1ObjectMeta +import pytest +from kubernetes.client import (V1ObjectMeta, V1ServiceAccount, + V1ServiceAccountList, rest) -from kserve.api.creds_utils import check_sa_exists +from kserve import constants +from kserve.api.creds_utils import (check_sa_exists, create_secret, + create_service_account, + get_creds_name_from_config_map, + patch_service_account, + set_azure_credentials, set_gcs_credentials, + set_s3_credentials, set_service_account) @mock.patch('kubernetes.client.CoreV1Api.list_namespaced_service_account') @@ -31,3 +41,169 @@ def test_check_sa_exists(mock_client): assert check_sa_exists('kubeflow', 'a') is True assert check_sa_exists('kubeflow', 'b') is True assert check_sa_exists('kubeflow', 'c') is False + + +@mock.patch('kubernetes.client.CoreV1Api.create_namespaced_service_account') +def test_create_service_account(mock_client): + sa_name = "test" + namespace = "kserve-test" + secret_name = "test_secret" + create_service_account(secret_name, namespace, sa_name) + mock_client.assert_called_once() + + mock_client.side_effect = rest.ApiException('foo') + with pytest.raises(RuntimeError): + sa_name = "test" + namespace = "kserve-test" + secret_name = "test_secret" + create_service_account(secret_name, namespace, sa_name) + + +@mock.patch('kubernetes.client.CoreV1Api.patch_namespaced_service_account') +def test_patch_service_account(mock_client): + sa_name = "test" + namespace = "kserve-test" + secret_name = "test_secret" + patch_service_account(secret_name, namespace, sa_name) + mock_client.assert_called_once() + + mock_client.side_effect = rest.ApiException('foo') + with pytest.raises(RuntimeError): + sa_name = "test" + namespace = "kserve-test" + secret_name = "test_secret" + patch_service_account(secret_name, namespace, sa_name) + + +@mock.patch('kubernetes.client.CoreV1Api.create_namespaced_secret') +def test_create_secret(mock_create_secret): + namespace = "test" + secret_name = "test-secret" + mock_create_secret.return_value = mock.Mock(**{"metadata.name": secret_name}) + assert create_secret(namespace) == secret_name + + with pytest.raises(RuntimeError): + mock_create_secret.side_effect = rest.ApiException('foo') + create_secret(namespace) + + +@mock.patch('kserve.api.creds_utils.create_service_account') +@mock.patch('kserve.api.creds_utils.patch_service_account') +@mock.patch('kserve.api.creds_utils.check_sa_exists') +def test_set_service_account(mock_check_sa_exists, mock_patch_service_account, mock_create_service_account): + namespace = "test" + service_account = V1ServiceAccount() + secret_name = "test-secret" + mock_check_sa_exists.return_value = True + set_service_account(namespace, service_account, secret_name) + mock_patch_service_account.assert_called_once() + + mock_check_sa_exists.return_value = False + set_service_account(namespace, service_account, secret_name) + mock_create_service_account.assert_called_once() + + +@mock.patch('kubernetes.client.CoreV1Api.read_namespaced_config_map') +def test_get_creds_name_from_config_map(mock_read_config_map): + mock_read_config_map.return_value = mock.Mock(**{"data": {"credentials": """{ + "gcs": {"gcsCredentialFileName": "gcs_cred.json"}, + "s3": {"s3AccessKeyIDName": "s3_access_key.json", + "s3SecretAccessKeyName": "s3_secret.json"}}""" + }}) + test_cases = {'gcsCredentialFileName': 'gcs_cred.json', + 's3AccessKeyIDName': 's3_access_key.json', + 's3SecretAccessKeyName': 's3_secret.json'} + for 
cred, result in test_cases.items(): + assert get_creds_name_from_config_map(cred) == result + + with pytest.raises(RuntimeError): + get_creds_name_from_config_map("invalidCred") + + mock_read_config_map.side_effect = rest.ApiException('foo') + assert get_creds_name_from_config_map('gcsCredentialFileName') is None + + +@mock.patch('kserve.api.creds_utils.set_service_account') +@mock.patch('kserve.api.creds_utils.create_secret') +@mock.patch('kserve.api.creds_utils.get_creds_name_from_config_map') +def test_set_gcs_credentials(mock_get_creds_name, mock_create_secret, mock_set_service_account): + namespace = "test" + service_account = V1ServiceAccount() + temp_cred_file = tempfile.NamedTemporaryFile(suffix=".json") + cred_file_name = temp_cred_file.name + mock_get_creds_name.return_value = cred_file_name + mock_create_secret.return_value = "test-secret" + set_gcs_credentials(namespace, cred_file_name, service_account) + mock_get_creds_name.assert_called() + mock_create_secret.assert_called() + mock_set_service_account.assert_called() + + mock_get_creds_name.return_value = None + set_gcs_credentials(namespace, cred_file_name, service_account) + mock_get_creds_name.assert_called() + mock_create_secret.assert_called() + mock_set_service_account.assert_called() + + +@mock.patch('kserve.api.creds_utils.set_service_account') +@mock.patch('kserve.api.creds_utils.create_secret') +@mock.patch('kserve.api.creds_utils.get_creds_name_from_config_map') +def test_set_s3_credentials(mock_get_creds_name, mock_create_secret, mock_set_service_account): + namespace = "test" + endpoint = "https://s3.aws.com" + region = "ap-south-1" + use_https = True + verify_ssl = True + cabundle = "/user/test/cert.pem" + data = { + constants.S3_ACCESS_KEY_ID_DEFAULT_NAME: "XXXXXXXXXXXX", + constants.S3_SECRET_ACCESS_KEY_DEFAULT_NAME: "XXXXXXXXXXXX", + } + annotations = {constants.KSERVE_GROUP + "/s3-endpoint": endpoint, + constants.KSERVE_GROUP + "/s3-region": region, + constants.KSERVE_GROUP + "/s3-usehttps": use_https, + constants.KSERVE_GROUP + "/s3-verifyssl": verify_ssl, + constants.KSERVE_GROUP + "/s3-cabundle": cabundle + } + creds_str = b""" + [default] + aws_access_key_id = XXXXXXXXXXXX + aws_secret_access_key = XXXXXXXXXXXX + """ + + with tempfile.NamedTemporaryFile() as creds_file: + creds_file.write(creds_str) + creds_file.seek(0) + mock_get_creds_name.return_value = None + mock_create_secret.return_value = "test-secret" + set_s3_credentials(namespace, creds_file.name, V1ServiceAccount(), s3_endpoint=endpoint, + s3_region=region, s3_use_https=use_https, s3_verify_ssl=verify_ssl, + s3_cabundle=cabundle) + mock_create_secret.assert_called_with(namespace=namespace, annotations=annotations, data=data) + mock_get_creds_name.assert_called() + mock_set_service_account.assert_called() + + +@mock.patch('kserve.api.creds_utils.set_service_account') +@mock.patch('kserve.api.creds_utils.create_secret') +def test_set_azure_credentials(mock_create_secret, mock_set_service_account): + namespace = "test" + creds = { + "clientId": "XXXXXXXXXXX", + "clientSecret": "XXXXXXXXXXX", + "subscriptionId": "XXXXXXXXXXX", + "tenantId": "XXXXXXXXXXX" + } + data = { + 'AZ_CLIENT_ID': creds['clientId'], + 'AZ_CLIENT_SECRET': creds['clientSecret'], + 'AZ_SUBSCRIPTION_ID': creds['subscriptionId'], + 'AZ_TENANT_ID': creds['tenantId'], + } + with tempfile.NamedTemporaryFile(suffix=".json") as creds_file: + creds_file.write(json.dumps(creds).encode("utf-8")) + creds_file.seek(0) + mock_create_secret.return_value = "test-secret" +
set_azure_credentials(namespace, creds_file.name, V1ServiceAccount()) + mock_create_secret.assert_called_with(namespace=namespace, data=data) + mock_set_service_account.assert_called()
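
The Python tests added above lean on two unittest.mock idioms: dotted keyword names configure nested attributes on a mock (which is how test_create_secret stubs the metadata.name of the secret returned by the patched create_namespaced_secret), and side_effect with rest.ApiException drives the error branches that are expected to surface as RuntimeError. A minimal, self-contained sketch of the pattern follows, assuming kubernetes and pytest are installed; the helper get_secret_name and the variable fake_api are illustrative only and are not part of the kserve code.

import pytest
from unittest import mock
from kubernetes.client import rest


def get_secret_name(api):
    # Illustrative helper: return the created secret's name, mapping API errors to RuntimeError.
    try:
        return api.create_namespaced_secret("ns", {}).metadata.name
    except rest.ApiException as e:
        raise RuntimeError(e)


def test_get_secret_name():
    fake_api = mock.Mock()
    # Dotted kwargs create the nested child mock, so .metadata.name resolves to "test-secret".
    fake_api.create_namespaced_secret.return_value = mock.Mock(**{"metadata.name": "test-secret"})
    assert get_secret_name(fake_api) == "test-secret"

    # side_effect makes the next call raise, exercising the error branch.
    fake_api.create_namespaced_secret.side_effect = rest.ApiException('foo')
    with pytest.raises(RuntimeError):
        get_secret_name(fake_api)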