diff --git a/backend/src/apiserver/client_manager/client_manager.go b/backend/src/apiserver/client_manager/client_manager.go index e74310f3e96..aabf42d2669 100644 --- a/backend/src/apiserver/client_manager/client_manager.go +++ b/backend/src/apiserver/client_manager/client_manager.go @@ -208,7 +208,8 @@ func (c *ClientManager) init() { c.k8sCoreClient = client.CreateKubernetesCoreOrFatal(common.GetDurationConfig(initConnectionTimeout), clientParams) - newClient, err := metadata.NewClient(common.GetMetadataGrpcServiceServiceHost(), common.GetMetadataGrpcServiceServicePort()) + newClient, err := metadata.NewClient(common.GetMetadataGrpcServiceServiceHost(), common.GetMetadataGrpcServiceServicePort(), common.GetMetadataTLSEnabled()) + if err != nil { glog.Fatalf("Failed to create metadata client. Error: %v", err) } diff --git a/backend/src/apiserver/common/config.go b/backend/src/apiserver/common/config.go index c0763ed0777..2d5adab1393 100644 --- a/backend/src/apiserver/common/config.go +++ b/backend/src/apiserver/common/config.go @@ -34,6 +34,7 @@ const ( TokenReviewAudience string = "TOKEN_REVIEW_AUDIENCE" MetadataGrpcServiceServiceHost string = "METADATA_GRPC_SERVICE_SERVICE_HOST" MetadataGrpcServiceServicePort string = "METADATA_GRPC_SERVICE_SERVICE_PORT" + MetadataTLSEnabled string = "METADATA_TLS_ENABLED" SignedURLExpiryTimeSeconds string = "SIGNED_URL_EXPIRY_TIME_SECONDS" ) @@ -142,3 +143,7 @@ func GetMetadataGrpcServiceServicePort() string { func GetSignedURLExpiryTimeSeconds() int { return GetIntConfigWithDefault(SignedURLExpiryTimeSeconds, DefaultSignedURLExpiryTimeSeconds) } + +func GetMetadataTLSEnabled() bool { + return GetBoolConfigWithDefault(MetadataTLSEnabled, DefaultMetadataTLSEnabled) +} diff --git a/backend/src/apiserver/common/const.go b/backend/src/apiserver/common/const.go index 168675e294c..282488e8d38 100644 --- a/backend/src/apiserver/common/const.go +++ b/backend/src/apiserver/common/const.go @@ -57,6 +57,7 @@ const DefaultTokenReviewAudience string = "pipelines.kubeflow.org" const ( DefaultMetadataGrpcServiceServiceHost = "metadata-grpc-service" DefaultMetadataGrpcServiceServicePort = "8080" + DefaultMetadataTLSEnabled = false ) const ( diff --git a/backend/src/v2/cmd/driver/main.go b/backend/src/v2/cmd/driver/main.go index 98127c28446..95faea8a3ba 100644 --- a/backend/src/v2/cmd/driver/main.go +++ b/backend/src/v2/cmd/driver/main.go @@ -70,6 +70,7 @@ var ( conditionPath = flag.String("condition_path", "", "Condition output path") mlPipelineServiceTLSEnabledStr = flag.String("mlPipelineServiceTLSEnabled", "false", "Set to 'true' if mlpipeline api server serves over TLS (default: 'false').") + metadataTLSEnabledStr = flag.String("metadataTLSEnabled", "false", "Set to 'true' if metadata server serves over TLS (default: 'false').") ) // func RootDAG(pipelineName string, runID string, component *pipelinespec.ComponentSpec, task *pipelinespec.PipelineTaskSpec, mlmd *metadata.Client) (*Execution, error) { @@ -154,6 +155,11 @@ func drive() (err error) { return err } + metadataTLSEnabled, err := strconv.ParseBool(*metadataTLSEnabledStr) + if err != nil { + return err + } + cacheClient, err := cacheutils.NewClient(mlPipelineServiceTLSEnabled) if err != nil { return err @@ -167,6 +173,9 @@ func drive() (err error) { DAGExecutionID: *dagExecutionID, IterationIndex: *iterationIndex, MLPipelineTLSEnabled: mlPipelineServiceTLSEnabled, + MLMDServerAddress: *mlmdServerAddress, + MLMDServerPort: *mlmdServerPort, + MLMDTLSEnabled: metadataTLSEnabled, } var execution *driver.Execution var driverErr error @@ -292,5 +301,11 @@ func newMlmdClient() (*metadata.Client, error) { mlmdConfig.Address = *mlmdServerAddress mlmdConfig.Port = *mlmdServerPort } - return metadata.NewClient(mlmdConfig.Address, mlmdConfig.Port) + + tlsEnabled, err := strconv.ParseBool(*metadataTLSEnabledStr) + if err != nil { + return nil, err + } + + return metadata.NewClient(mlmdConfig.Address, mlmdConfig.Port, tlsEnabled) } diff --git a/backend/src/v2/cmd/launcher-v2/main.go b/backend/src/v2/cmd/launcher-v2/main.go index 3ac4245142f..1e11edb8b95 100644 --- a/backend/src/v2/cmd/launcher-v2/main.go +++ b/backend/src/v2/cmd/launcher-v2/main.go @@ -43,6 +43,7 @@ var ( mlmdServerAddress = flag.String("mlmd_server_address", "", "The MLMD gRPC server address.") mlmdServerPort = flag.String("mlmd_server_port", "8080", "The MLMD gRPC server port.") mlPipelineServiceTLSEnabledStr = flag.String("mlPipelineServiceTLSEnabled", "false", "Set to 'true' if mlpipeline api server serves over TLS (default: 'false').") + metadataTLSEnabledStr = flag.String("metadataTLSEnabled", "false", "Set to 'true' if metadata server serves over TLS (default: 'false').") ) func main() { @@ -71,6 +72,12 @@ func run() error { if err != nil { return err } + + metadataServiceTLSEnabled, err := strconv.ParseBool(*metadataTLSEnabledStr) + if err != nil { + return err + } + launcherV2Opts := &component.LauncherV2Options{ Namespace: namespace, PodName: *podName, @@ -80,6 +87,7 @@ func run() error { PipelineName: *pipelineName, RunID: *runID, MLPipelineTLSEnabled: mlPipelineServiceTLSEnabled, + MetadataTLSEnabled: metadataServiceTLSEnabled, } switch *executorType { diff --git a/backend/src/v2/compiler/argocompiler/container.go b/backend/src/v2/compiler/argocompiler/container.go index 7b12ca174d1..87726854562 100644 --- a/backend/src/v2/compiler/argocompiler/container.go +++ b/backend/src/v2/compiler/argocompiler/container.go @@ -17,6 +17,7 @@ package argocompiler import ( wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1" "github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec" + "github.com/kubeflow/pipelines/backend/src/apiserver/common" "github.com/kubeflow/pipelines/backend/src/v2/component" k8score "k8s.io/api/core/v1" "os" @@ -163,6 +164,9 @@ func (c *workflowCompiler) addContainerDriverTemplate() string { "--condition_path", outputPath(paramCondition), "--kubernetes_config", inputValue(paramKubernetesConfig), "--mlPipelineServiceTLSEnabled", strconv.FormatBool(c.mlPipelineServiceTLSEnabled), + "--mlmd_server_address", common.GetMetadataGrpcServiceServiceHost(), + "--mlmd_server_port", common.GetMetadataGrpcServiceServicePort(), + "--metadataTLSEnabled", strconv.FormatBool(common.GetMetadataTLSEnabled()), }, Resources: driverResources, }, diff --git a/backend/src/v2/compiler/argocompiler/dag.go b/backend/src/v2/compiler/argocompiler/dag.go index 36a239667e3..da9889d9894 100644 --- a/backend/src/v2/compiler/argocompiler/dag.go +++ b/backend/src/v2/compiler/argocompiler/dag.go @@ -15,6 +15,7 @@ package argocompiler import ( "fmt" + "github.com/kubeflow/pipelines/backend/src/apiserver/common" "sort" "strconv" "strings" @@ -443,6 +444,9 @@ func (c *workflowCompiler) addDAGDriverTemplate() string { "--iteration_count_path", outputPath(paramIterationCount), "--condition_path", outputPath(paramCondition), "--mlPipelineServiceTLSEnabled", strconv.FormatBool(c.mlPipelineServiceTLSEnabled), + "--mlmd_server_address", common.GetMetadataGrpcServiceServiceHost(), + "--mlmd_server_port", common.GetMetadataGrpcServiceServicePort(), + "--metadataTLSEnabled", strconv.FormatBool(common.GetMetadataTLSEnabled()), }, Resources: driverResources, }, diff --git a/backend/src/v2/compiler/argocompiler/importer.go b/backend/src/v2/compiler/argocompiler/importer.go index 83ac6453b64..e84c2d673b1 100644 --- a/backend/src/v2/compiler/argocompiler/importer.go +++ b/backend/src/v2/compiler/argocompiler/importer.go @@ -16,6 +16,7 @@ package argocompiler import ( "fmt" + "github.com/kubeflow/pipelines/backend/src/apiserver/common" wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1" "github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec" @@ -76,10 +77,8 @@ func (c *workflowCompiler) addImporterTemplate() string { fmt.Sprintf("$(%s)", component.EnvPodName), "--pod_uid", fmt.Sprintf("$(%s)", component.EnvPodUID), - "--mlmd_server_address", - fmt.Sprintf("$(%s)", component.EnvMetadataHost), - "--mlmd_server_port", - fmt.Sprintf("$(%s)", component.EnvMetadataPort), + "--mlmd_server_address", common.GetMetadataGrpcServiceServiceHost(), + "--mlmd_server_port", common.GetMetadataGrpcServiceServicePort(), } importerTemplate := &wfapi.Template{ Name: name, diff --git a/backend/src/v2/compiler/argocompiler/testdata/create_mount_delete_dynamic_pvc.yaml b/backend/src/v2/compiler/argocompiler/testdata/create_mount_delete_dynamic_pvc.yaml index d4cd73085df..b15ccdbbddc 100644 --- a/backend/src/v2/compiler/argocompiler/testdata/create_mount_delete_dynamic_pvc.yaml +++ b/backend/src/v2/compiler/argocompiler/testdata/create_mount_delete_dynamic_pvc.yaml @@ -67,6 +67,12 @@ spec: - '{{inputs.parameters.kubernetes-config}}' - "--mlPipelineServiceTLSEnabled" - "false" + - "--mlmd_server_address" + - "metadata-grpc-service" + - "--mlmd_server_port" + - "8080" + - "--metadataTLSEnabled" + - "false" command: - driver image: gcr.io/ml-pipeline/kfp-driver @@ -312,6 +318,12 @@ spec: - '{{outputs.parameters.condition.path}}' - "--mlPipelineServiceTLSEnabled" - "false" + - "--mlmd_server_address" + - "metadata-grpc-service" + - "--mlmd_server_port" + - "8080" + - "--metadataTLSEnabled" + - "false" command: - driver image: gcr.io/ml-pipeline/kfp-driver diff --git a/backend/src/v2/compiler/argocompiler/testdata/hello_world.yaml b/backend/src/v2/compiler/argocompiler/testdata/hello_world.yaml index e285ad07188..c7ee5ed98bf 100644 --- a/backend/src/v2/compiler/argocompiler/testdata/hello_world.yaml +++ b/backend/src/v2/compiler/argocompiler/testdata/hello_world.yaml @@ -50,6 +50,12 @@ spec: - '{{inputs.parameters.kubernetes-config}}' - "--mlPipelineServiceTLSEnabled" - "false" + - "--mlmd_server_address" + - "metadata-grpc-service" + - "--mlmd_server_port" + - "8080" + - "--metadataTLSEnabled" + - "false" command: - driver image: gcr.io/ml-pipeline/kfp-driver @@ -242,6 +248,12 @@ spec: - '{{outputs.parameters.condition.path}}' - "--mlPipelineServiceTLSEnabled" - "false" + - "--mlmd_server_address" + - "metadata-grpc-service" + - "--mlmd_server_port" + - "8080" + - "--metadataTLSEnabled" + - "false" env: - name: ML_PIPELINE_SERVICE_HOST value: ml-pipeline.kubeflow.svc.cluster.local diff --git a/backend/src/v2/compiler/argocompiler/testdata/importer.yaml b/backend/src/v2/compiler/argocompiler/testdata/importer.yaml index 0e2d30a12b2..781b5be9006 100644 --- a/backend/src/v2/compiler/argocompiler/testdata/importer.yaml +++ b/backend/src/v2/compiler/argocompiler/testdata/importer.yaml @@ -38,9 +38,9 @@ spec: - --pod_uid - $(KFP_POD_UID) - --mlmd_server_address - - $(METADATA_GRPC_SERVICE_HOST) + - "metadata-grpc-service" - --mlmd_server_port - - $(METADATA_GRPC_SERVICE_PORT) + - "8080" command: - launcher-v2 env: @@ -120,6 +120,12 @@ spec: - '{{outputs.parameters.condition.path}}' - "--mlPipelineServiceTLSEnabled" - "false" + - "--mlmd_server_address" + - "metadata-grpc-service" + - "--mlmd_server_port" + - "8080" + - "--metadataTLSEnabled" + - "false" command: - driver image: gcr.io/ml-pipeline/kfp-driver diff --git a/backend/src/v2/component/importer_launcher.go b/backend/src/v2/component/importer_launcher.go index e6dae29d639..44c60229887 100644 --- a/backend/src/v2/component/importer_launcher.go +++ b/backend/src/v2/component/importer_launcher.go @@ -88,7 +88,7 @@ func NewImporterLauncher(ctx context.Context, componentSpecJSON, importerSpecJSO if err != nil { return nil, fmt.Errorf("failed to initialize kubernetes client set: %w", err) } - metadataClient, err := metadata.NewClient(launcherV2Opts.MLMDServerAddress, launcherV2Opts.MLMDServerPort) + metadataClient, err := metadata.NewClient(launcherV2Opts.MLMDServerAddress, launcherV2Opts.MLMDServerPort, launcherV2Opts.MetadataTLSEnabled) if err != nil { return nil, err } diff --git a/backend/src/v2/component/launcher_v2.go b/backend/src/v2/component/launcher_v2.go index b7682d5a4e5..f57d675f616 100644 --- a/backend/src/v2/component/launcher_v2.go +++ b/backend/src/v2/component/launcher_v2.go @@ -54,6 +54,8 @@ type LauncherV2Options struct { RunID string // set to true if ml pipeline server is serving over tls MLPipelineTLSEnabled bool + // set to true if metadata server is serving over tls + MetadataTLSEnabled bool } type LauncherV2 struct { @@ -110,7 +112,7 @@ func NewLauncherV2(ctx context.Context, executionID int64, executorInputJSON, co if err != nil { return nil, fmt.Errorf("failed to initialize kubernetes client set: %w", err) } - metadataClient, err := metadata.NewClient(opts.MLMDServerAddress, opts.MLMDServerPort) + metadataClient, err := metadata.NewClient(opts.MLMDServerAddress, opts.MLMDServerPort, opts.MetadataTLSEnabled) if err != nil { return nil, err } diff --git a/backend/src/v2/driver/driver.go b/backend/src/v2/driver/driver.go index b2f0e15c6a0..a98a61f58ed 100644 --- a/backend/src/v2/driver/driver.go +++ b/backend/src/v2/driver/driver.go @@ -77,6 +77,13 @@ type Options struct { // set to true if ml pipeline server is serving over tls MLPipelineTLSEnabled bool + + MLMDServerAddress string + + MLMDServerPort string + + // set to true if MLMD server is serving over tls + MLMDTLSEnabled bool } // Identifying information used for error messages @@ -339,7 +346,7 @@ func Container(ctx context.Context, opts Options, mlmd *metadata.Client, cacheCl return execution, nil } - podSpec, err := initPodSpecPatch(opts.Container, opts.Component, executorInput, execution.ID, opts.PipelineName, opts.RunID, opts.MLPipelineTLSEnabled) + podSpec, err := initPodSpecPatch(opts.Container, opts.Component, executorInput, execution.ID, opts.PipelineName, opts.RunID, opts.MLPipelineTLSEnabled, opts.MLMDServerAddress, opts.MLMDServerPort, opts.MLMDTLSEnabled) if err != nil { return execution, err } @@ -373,6 +380,9 @@ func initPodSpecPatch( pipelineName string, runID string, mlPipelineTLSEnabled bool, + mlmdServerAddress string, + mlmdServerPort string, + mlmdTLSEnabled bool, ) (*k8score.PodSpec, error) { executorInputJSON, err := protojson.Marshal(executorInput) if err != nil { @@ -407,10 +417,9 @@ func initPodSpecPatch( fmt.Sprintf("$(%s)", component.EnvPodName), "--pod_uid", fmt.Sprintf("$(%s)", component.EnvPodUID), - "--mlmd_server_address", - fmt.Sprintf("$(%s)", component.EnvMetadataHost), - "--mlmd_server_port", - fmt.Sprintf("$(%s)", component.EnvMetadataPort), + "--mlmd_server_address", mlmdServerAddress, + "--mlmd_server_port", mlmdServerPort, + "--metadataTLSEnabled", fmt.Sprintf("%v", mlmdTLSEnabled), "--mlPipelineServiceTLSEnabled", fmt.Sprintf("%v", mlPipelineTLSEnabled), "--", // separater before user command and args diff --git a/backend/src/v2/driver/driver_test.go b/backend/src/v2/driver/driver_test.go index 34ed4d13bb3..523b6b3bf9e 100644 --- a/backend/src/v2/driver/driver_test.go +++ b/backend/src/v2/driver/driver_test.go @@ -241,7 +241,7 @@ func Test_initPodSpecPatch_acceleratorConfig(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false) + podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false, "unused-mlmd-server-address", "unused-mlmd-server-port", false) if tt.wantErr { assert.Nil(t, podSpec) assert.NotNil(t, err) @@ -403,7 +403,7 @@ func Test_initPodSpecPatch_resourceRequests(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false) + podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false, "unused-mlmd-server-address", "unused-mlmd-server-port", false) assert.Nil(t, err) assert.NotEmpty(t, podSpec) podSpecString, err := json.Marshal(podSpec) diff --git a/backend/src/v2/metadata/client.go b/backend/src/v2/metadata/client.go index eaaae44896a..c548081c670 100644 --- a/backend/src/v2/metadata/client.go +++ b/backend/src/v2/metadata/client.go @@ -18,8 +18,11 @@ package metadata import ( "context" + "crypto/tls" "errors" "fmt" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" "path" "strconv" "strings" @@ -105,14 +108,21 @@ type Client struct { } // NewClient creates a Client given the MLMD server address and port. -func NewClient(serverAddress, serverPort string) (*Client, error) { +func NewClient(serverAddress, serverPort string, tlsEnabled bool) (*Client, error) { opts := []grpc_retry.CallOption{ grpc_retry.WithMax(mlmdClientSideMaxRetries), grpc_retry.WithBackoff(grpc_retry.BackoffExponentialWithJitter(300*time.Millisecond, 0.20)), grpc_retry.WithCodes(codes.Aborted), } + + creds := insecure.NewCredentials() + if tlsEnabled { + config := &tls.Config{} + creds = credentials.NewTLS(config) + } + conn, err := grpc.Dial(fmt.Sprintf("%s:%s", serverAddress, serverPort), - grpc.WithInsecure(), + grpc.WithTransportCredentials(creds), grpc.WithStreamInterceptor(grpc_retry.StreamClientInterceptor(opts...)), grpc.WithUnaryInterceptor(grpc_retry.UnaryClientInterceptor(opts...)), ) diff --git a/backend/src/v2/metadata/client_test.go b/backend/src/v2/metadata/client_test.go index ea3bf34dde1..d8e25efef3c 100644 --- a/backend/src/v2/metadata/client_test.go +++ b/backend/src/v2/metadata/client_test.go @@ -84,7 +84,7 @@ func Test_GetPipeline(t *testing.T) { runUuid, err := uuid.NewRandom() fatalIf(err) runId := runUuid.String() - client, err := metadata.NewClient(testMlmdServerAddress, testMlmdServerPort) + client, err := metadata.NewClient(testMlmdServerAddress, testMlmdServerPort, false) fatalIf(err) mlmdClient, err := NewTestMlmdClient() fatalIf(err) @@ -135,7 +135,7 @@ func Test_GetPipeline_Twice(t *testing.T) { runUuid, err := uuid.NewRandom() fatalIf(err) runId := runUuid.String() - client, err := metadata.NewClient(testMlmdServerAddress, testMlmdServerPort) + client, err := metadata.NewClient(testMlmdServerAddress, testMlmdServerPort, false) fatalIf(err) pipeline, err := client.GetPipeline(ctx, "get-pipeline-test", runId, namespace, runResource, pipelineRoot, "") @@ -177,7 +177,7 @@ func Test_GetPipelineConcurrently(t *testing.T) { t.Skip("Temporarily disable the test that requires cluster connection.") // This test depends on a MLMD grpc server running at localhost:8080. - client, err := metadata.NewClient("localhost", "8080") + client, err := metadata.NewClient("localhost", "8080", false) if err != nil { t.Fatal(err) } @@ -281,7 +281,7 @@ func Test_DAG(t *testing.T) { func newLocalClientOrFatal(t *testing.T) *metadata.Client { t.Helper() - client, err := metadata.NewClient("localhost", "8080") + client, err := metadata.NewClient("localhost", "8080", false) if err != nil { t.Fatalf("metadata.NewClient failed: %v", err) }