From 474c55a53257faba91b78590359fd858621a7d15 Mon Sep 17 00:00:00 2001 From: Helber Belmiro Date: Wed, 30 Oct 2024 10:10:55 -0300 Subject: [PATCH] UPSTREAM: : Use DSPA custom ca cert on MLMD and Persistence Agent clients Signed-off-by: Helber Belmiro --- .../persistence/client/pipeline_client.go | 5 +++-- backend/src/agent/persistence/main.go | 7 +++++-- .../client_manager/client_manager.go | 2 +- backend/src/apiserver/common/config.go | 13 +++++++++++++ backend/src/common/util/service.go | 19 +++++++++++++++++-- backend/src/v2/cmd/driver/main.go | 4 +++- backend/src/v2/cmd/launcher-v2/main.go | 2 ++ .../src/v2/compiler/argocompiler/common.go | 7 ++++--- .../src/v2/compiler/argocompiler/container.go | 1 + backend/src/v2/compiler/argocompiler/dag.go | 1 + .../src/v2/compiler/argocompiler/importer.go | 4 ++++ .../create_mount_delete_dynamic_pvc.yaml | 4 ++++ .../testdata/create_pod_metadata.yaml | 4 ++++ .../argocompiler/testdata/hello_world.yaml | 4 ++++ .../argocompiler/testdata/importer.yaml | 4 ++++ backend/src/v2/component/importer_launcher.go | 2 +- backend/src/v2/component/launcher_v2.go | 3 ++- backend/src/v2/driver/driver.go | 6 +++++- backend/src/v2/driver/driver_test.go | 6 +++--- backend/src/v2/metadata/client.go | 18 ++++++++++++++++-- backend/src/v2/metadata/client_test.go | 8 ++++---- 21 files changed, 101 insertions(+), 23 deletions(-) diff --git a/backend/src/agent/persistence/client/pipeline_client.go b/backend/src/agent/persistence/client/pipeline_client.go index 10da3822c19..546c1783644 100644 --- a/backend/src/agent/persistence/client/pipeline_client.go +++ b/backend/src/agent/persistence/client/pipeline_client.go @@ -56,7 +56,8 @@ func NewPipelineClient( mlPipelineServiceName string, mlPipelineServiceHttpPort string, mlPipelineServiceGRPCPort string, - mlPipelineServiceTLSEnabled bool) (*PipelineClient, error) { + mlPipelineServiceTLSEnabled bool, + caCertPath string) (*PipelineClient, error) { httpAddress := fmt.Sprintf(addressTemp, mlPipelineServiceName, mlPipelineServiceHttpPort) grpcAddress := fmt.Sprintf(addressTemp, mlPipelineServiceName, mlPipelineServiceGRPCPort) scheme := "http" @@ -68,7 +69,7 @@ func NewPipelineClient( return nil, errors.Wrapf(err, "Failed to initialize pipeline client. Error: %s", err.Error()) } - connection, err := util.GetRpcConnection(grpcAddress, mlPipelineServiceTLSEnabled) + connection, err := util.GetRpcConnection(grpcAddress, mlPipelineServiceTLSEnabled, caCertPath) if err != nil { return nil, errors.Wrapf(err, "Failed to get RPC connection. Error: %s", err.Error()) diff --git a/backend/src/agent/persistence/main.go b/backend/src/agent/persistence/main.go index 0f7b5762ef6..78559f9548c 100644 --- a/backend/src/agent/persistence/main.go +++ b/backend/src/agent/persistence/main.go @@ -48,6 +48,7 @@ var ( clientBurst int executionType string saTokenRefreshIntervalInSecs int64 + caCertPath string ) const ( @@ -68,6 +69,7 @@ const ( clientBurstFlagName = "clientBurst" executionTypeFlagName = "executionType" saTokenRefreshIntervalFlagName = "saTokenRefreshIntervalInSecs" + caCertPathFlagName = "caCertPath" ) const ( @@ -135,7 +137,8 @@ func main() { mlPipelineAPIServerName, mlPipelineServiceHttpPort, mlPipelineServiceGRPCPort, - mlPipelineServiceTLSEnabled) + mlPipelineServiceTLSEnabled, + caCertPath) if err != nil { log.Fatalf("Error creating ML pipeline API Server client: %v", err) } @@ -177,5 +180,5 @@ func init() { // TODO use viper/config file instead. Sync `saTokenRefreshIntervalFlagName` with the value from manifest file by using ENV var. flag.Int64Var(&saTokenRefreshIntervalInSecs, saTokenRefreshIntervalFlagName, DefaultSATokenRefresherIntervalInSecs, "Persistence agent service account token read interval in seconds. "+ "Defines how often `/var/run/secrets/kubeflow/tokens/kubeflow-persistent_agent-api-token` to be read") - + flag.StringVar(&caCertPath, caCertPathFlagName, "", "The path to the CA certificate.") } diff --git a/backend/src/apiserver/client_manager/client_manager.go b/backend/src/apiserver/client_manager/client_manager.go index 42c98ce467c..861f66e2d2f 100644 --- a/backend/src/apiserver/client_manager/client_manager.go +++ b/backend/src/apiserver/client_manager/client_manager.go @@ -208,7 +208,7 @@ func (c *ClientManager) init() { c.k8sCoreClient = client.CreateKubernetesCoreOrFatal(common.GetDurationConfig(initConnectionTimeout), clientParams) - newClient, err := metadata.NewClient(common.GetMetadataGrpcServiceServiceHost(), common.GetMetadataGrpcServiceServicePort(), common.GetMetadataTLSEnabled()) + newClient, err := metadata.NewClient(common.GetMetadataGrpcServiceServiceHost(), common.GetMetadataGrpcServiceServicePort(), common.GetMetadataTLSEnabled(), common.GetCaCertPath()) if err != nil { glog.Fatalf("Failed to create metadata client. Error: %v", err) diff --git a/backend/src/apiserver/common/config.go b/backend/src/apiserver/common/config.go index 2d5adab1393..21e1d62a994 100644 --- a/backend/src/apiserver/common/config.go +++ b/backend/src/apiserver/common/config.go @@ -36,6 +36,9 @@ const ( MetadataGrpcServiceServicePort string = "METADATA_GRPC_SERVICE_SERVICE_PORT" MetadataTLSEnabled string = "METADATA_TLS_ENABLED" SignedURLExpiryTimeSeconds string = "SIGNED_URL_EXPIRY_TIME_SECONDS" + CaBundleMountPath string = "ARTIFACT_COPY_STEP_CABUNDLE_MOUNTPATH" + CaBundleConfigMapKey string = "ARTIFACT_COPY_STEP_CABUNDLE_CONFIGMAP_KEY" + CaBundleConfigMapName string = "ARTIFACT_COPY_STEP_CABUNDLE_CONFIGMAP_NAME" ) func IsPipelineVersionUpdatedByDefault() bool { @@ -147,3 +150,13 @@ func GetSignedURLExpiryTimeSeconds() int { func GetMetadataTLSEnabled() bool { return GetBoolConfigWithDefault(MetadataTLSEnabled, DefaultMetadataTLSEnabled) } + +func GetCaCertPath() string { + caBundleMountPath := GetStringConfigWithDefault(CaBundleMountPath, "") + if caBundleMountPath != "" { + caBundleConfigMapKey := GetStringConfigWithDefault(CaBundleConfigMapKey, "") + return caBundleMountPath + "/" + caBundleConfigMapKey + } else { + return "" + } +} diff --git a/backend/src/common/util/service.go b/backend/src/common/util/service.go index 8544963db3b..8f5b4486553 100644 --- a/backend/src/common/util/service.go +++ b/backend/src/common/util/service.go @@ -16,10 +16,12 @@ package util import ( "crypto/tls" + "crypto/x509" "fmt" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "net/http" + "os" "strings" "time" @@ -77,10 +79,23 @@ func GetKubernetesClientFromClientConfig(clientConfig clientcmd.ClientConfig) ( return clientSet, config, namespace, nil } -func GetRpcConnection(address string, tlsEnabled bool) (*grpc.ClientConn, error) { +func GetRpcConnection(address string, tlsEnabled bool, caCertPath string) (*grpc.ClientConn, error) { creds := insecure.NewCredentials() if tlsEnabled { - config := &tls.Config{} + if caCertPath == "" { + return nil, errors.New("CA cert path is empty") + } + + caCert, err := os.ReadFile(caCertPath) + if err != nil { + return nil, errors.Wrap(err, "Encountered error when reading CA cert path for creating a metadata client.") + } + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + + config := &tls.Config{ + RootCAs: caCertPool, + } creds = credentials.NewTLS(config) } diff --git a/backend/src/v2/cmd/driver/main.go b/backend/src/v2/cmd/driver/main.go index 95faea8a3ba..5602780ed6f 100644 --- a/backend/src/v2/cmd/driver/main.go +++ b/backend/src/v2/cmd/driver/main.go @@ -71,6 +71,7 @@ var ( mlPipelineServiceTLSEnabledStr = flag.String("mlPipelineServiceTLSEnabled", "false", "Set to 'true' if mlpipeline api server serves over TLS (default: 'false').") metadataTLSEnabledStr = flag.String("metadataTLSEnabled", "false", "Set to 'true' if metadata server serves over TLS (default: 'false').") + caCertPath = flag.String("ca_cert_path", "", "The path to the CA certificate.") ) // func RootDAG(pipelineName string, runID string, component *pipelinespec.ComponentSpec, task *pipelinespec.PipelineTaskSpec, mlmd *metadata.Client) (*Execution, error) { @@ -176,6 +177,7 @@ func drive() (err error) { MLMDServerAddress: *mlmdServerAddress, MLMDServerPort: *mlmdServerPort, MLMDTLSEnabled: metadataTLSEnabled, + CaCertPath: *caCertPath, } var execution *driver.Execution var driverErr error @@ -307,5 +309,5 @@ func newMlmdClient() (*metadata.Client, error) { return nil, err } - return metadata.NewClient(mlmdConfig.Address, mlmdConfig.Port, tlsEnabled) + return metadata.NewClient(mlmdConfig.Address, mlmdConfig.Port, tlsEnabled, *caCertPath) } diff --git a/backend/src/v2/cmd/launcher-v2/main.go b/backend/src/v2/cmd/launcher-v2/main.go index 1e11edb8b95..44eefcf56b2 100644 --- a/backend/src/v2/cmd/launcher-v2/main.go +++ b/backend/src/v2/cmd/launcher-v2/main.go @@ -44,6 +44,7 @@ var ( mlmdServerPort = flag.String("mlmd_server_port", "8080", "The MLMD gRPC server port.") mlPipelineServiceTLSEnabledStr = flag.String("mlPipelineServiceTLSEnabled", "false", "Set to 'true' if mlpipeline api server serves over TLS (default: 'false').") metadataTLSEnabledStr = flag.String("metadataTLSEnabled", "false", "Set to 'true' if metadata server serves over TLS (default: 'false').") + caCertPath = flag.String("ca_cert_path", "", "The path to the CA certificate.") ) func main() { @@ -88,6 +89,7 @@ func run() error { RunID: *runID, MLPipelineTLSEnabled: mlPipelineServiceTLSEnabled, MetadataTLSEnabled: metadataServiceTLSEnabled, + CaCertPath: *caCertPath, } switch *executorType { diff --git a/backend/src/v2/compiler/argocompiler/common.go b/backend/src/v2/compiler/argocompiler/common.go index 75684510511..c3c47143144 100644 --- a/backend/src/v2/compiler/argocompiler/common.go +++ b/backend/src/v2/compiler/argocompiler/common.go @@ -17,6 +17,7 @@ package argocompiler import ( "fmt" wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1" + "github.com/kubeflow/pipelines/backend/src/apiserver/common" k8score "k8s.io/api/core/v1" "os" "strconv" @@ -98,9 +99,9 @@ func GetMLPipelineServicePortGRPC() string { // ConfigureCABundle adds CABundle environment variables and volume mounts // if CA Bundle env vars are specified. func ConfigureCABundle(tmpl *wfapi.Template) { - caBundleCfgMapName := os.Getenv("ARTIFACT_COPY_STEP_CABUNDLE_CONFIGMAP_NAME") - caBundleCfgMapKey := os.Getenv("ARTIFACT_COPY_STEP_CABUNDLE_CONFIGMAP_KEY") - caBundleMountPath := os.Getenv("ARTIFACT_COPY_STEP_CABUNDLE_MOUNTPATH") + caBundleCfgMapName := os.Getenv(common.CaBundleConfigMapName) + caBundleCfgMapKey := os.Getenv(common.CaBundleConfigMapKey) + caBundleMountPath := os.Getenv(common.CaBundleMountPath) if caBundleCfgMapName != "" && caBundleCfgMapKey != "" { caFile := fmt.Sprintf("%s/%s", caBundleMountPath, caBundleCfgMapKey) var certDirectories = []string{ diff --git a/backend/src/v2/compiler/argocompiler/container.go b/backend/src/v2/compiler/argocompiler/container.go index 8517330096e..920730bbd6f 100644 --- a/backend/src/v2/compiler/argocompiler/container.go +++ b/backend/src/v2/compiler/argocompiler/container.go @@ -171,6 +171,7 @@ func (c *workflowCompiler) addContainerDriverTemplate() string { "--mlmd_server_address", common.GetMetadataGrpcServiceServiceHost(), "--mlmd_server_port", common.GetMetadataGrpcServiceServicePort(), "--metadataTLSEnabled", strconv.FormatBool(common.GetMetadataTLSEnabled()), + "--ca_cert_path", common.GetCaCertPath(), }, Resources: driverResources, }, diff --git a/backend/src/v2/compiler/argocompiler/dag.go b/backend/src/v2/compiler/argocompiler/dag.go index 6716dea9f4d..ebae7d2c667 100644 --- a/backend/src/v2/compiler/argocompiler/dag.go +++ b/backend/src/v2/compiler/argocompiler/dag.go @@ -447,6 +447,7 @@ func (c *workflowCompiler) addDAGDriverTemplate() string { "--mlmd_server_address", common.GetMetadataGrpcServiceServiceHost(), "--mlmd_server_port", common.GetMetadataGrpcServiceServicePort(), "--metadataTLSEnabled", strconv.FormatBool(common.GetMetadataTLSEnabled()), + "--ca_cert_path", common.GetCaCertPath(), }, Resources: driverResources, }, diff --git a/backend/src/v2/compiler/argocompiler/importer.go b/backend/src/v2/compiler/argocompiler/importer.go index 2b49b09fc14..236112a891a 100644 --- a/backend/src/v2/compiler/argocompiler/importer.go +++ b/backend/src/v2/compiler/argocompiler/importer.go @@ -81,6 +81,7 @@ func (c *workflowCompiler) addImporterTemplate() string { "--mlmd_server_address", common.GetMetadataGrpcServiceServiceHost(), "--mlmd_server_port", common.GetMetadataGrpcServiceServicePort(), "--metadataTLSEnabled", strconv.FormatBool(common.GetMetadataTLSEnabled()), + "--ca_cert_path", common.GetCaCertPath(), } importerTemplate := &wfapi.Template{ Name: name, @@ -101,6 +102,9 @@ func (c *workflowCompiler) addImporterTemplate() string { Resources: driverResources, }, } + + ConfigureCABundle(importerTemplate) + c.templates[name] = importerTemplate c.wf.Spec.Templates = append(c.wf.Spec.Templates, *importerTemplate) return name diff --git a/backend/src/v2/compiler/argocompiler/testdata/create_mount_delete_dynamic_pvc.yaml b/backend/src/v2/compiler/argocompiler/testdata/create_mount_delete_dynamic_pvc.yaml index b15ccdbbddc..3feb4bdea31 100644 --- a/backend/src/v2/compiler/argocompiler/testdata/create_mount_delete_dynamic_pvc.yaml +++ b/backend/src/v2/compiler/argocompiler/testdata/create_mount_delete_dynamic_pvc.yaml @@ -73,6 +73,8 @@ spec: - "8080" - "--metadataTLSEnabled" - "false" + - "--ca_cert_path" + - "" command: - driver image: gcr.io/ml-pipeline/kfp-driver @@ -324,6 +326,8 @@ spec: - "8080" - "--metadataTLSEnabled" - "false" + - "--ca_cert_path" + - "" command: - driver image: gcr.io/ml-pipeline/kfp-driver diff --git a/backend/src/v2/compiler/argocompiler/testdata/create_pod_metadata.yaml b/backend/src/v2/compiler/argocompiler/testdata/create_pod_metadata.yaml index fa1318e78dc..d38a37a78b7 100644 --- a/backend/src/v2/compiler/argocompiler/testdata/create_pod_metadata.yaml +++ b/backend/src/v2/compiler/argocompiler/testdata/create_pod_metadata.yaml @@ -57,6 +57,8 @@ spec: - "8080" - "--metadataTLSEnabled" - "false" + - "--ca_cert_path" + - "" command: - driver image: gcr.io/ml-pipeline/kfp-driver @@ -263,6 +265,8 @@ spec: - "8080" - "--metadataTLSEnabled" - "false" + - "--ca_cert_path" + - "" command: - driver image: gcr.io/ml-pipeline/kfp-driver diff --git a/backend/src/v2/compiler/argocompiler/testdata/hello_world.yaml b/backend/src/v2/compiler/argocompiler/testdata/hello_world.yaml index c7ee5ed98bf..de741e37601 100644 --- a/backend/src/v2/compiler/argocompiler/testdata/hello_world.yaml +++ b/backend/src/v2/compiler/argocompiler/testdata/hello_world.yaml @@ -56,6 +56,8 @@ spec: - "8080" - "--metadataTLSEnabled" - "false" + - "--ca_cert_path" + - "" command: - driver image: gcr.io/ml-pipeline/kfp-driver @@ -254,6 +256,8 @@ spec: - "8080" - "--metadataTLSEnabled" - "false" + - "--ca_cert_path" + - "" env: - name: ML_PIPELINE_SERVICE_HOST value: ml-pipeline.kubeflow.svc.cluster.local diff --git a/backend/src/v2/compiler/argocompiler/testdata/importer.yaml b/backend/src/v2/compiler/argocompiler/testdata/importer.yaml index 9651dad2449..a89d22dfccd 100644 --- a/backend/src/v2/compiler/argocompiler/testdata/importer.yaml +++ b/backend/src/v2/compiler/argocompiler/testdata/importer.yaml @@ -43,6 +43,8 @@ spec: - "8080" - --metadataTLSEnabled - "false" + - "--ca_cert_path" + - "" command: - launcher-v2 env: @@ -128,6 +130,8 @@ spec: - "8080" - "--metadataTLSEnabled" - "false" + - "--ca_cert_path" + - "" command: - driver image: gcr.io/ml-pipeline/kfp-driver diff --git a/backend/src/v2/component/importer_launcher.go b/backend/src/v2/component/importer_launcher.go index 44c60229887..1e5f7a79e61 100644 --- a/backend/src/v2/component/importer_launcher.go +++ b/backend/src/v2/component/importer_launcher.go @@ -88,7 +88,7 @@ func NewImporterLauncher(ctx context.Context, componentSpecJSON, importerSpecJSO if err != nil { return nil, fmt.Errorf("failed to initialize kubernetes client set: %w", err) } - metadataClient, err := metadata.NewClient(launcherV2Opts.MLMDServerAddress, launcherV2Opts.MLMDServerPort, launcherV2Opts.MetadataTLSEnabled) + metadataClient, err := metadata.NewClient(launcherV2Opts.MLMDServerAddress, launcherV2Opts.MLMDServerPort, launcherV2Opts.MetadataTLSEnabled, launcherV2Opts.CaCertPath) if err != nil { return nil, err } diff --git a/backend/src/v2/component/launcher_v2.go b/backend/src/v2/component/launcher_v2.go index cacda3b6da1..3b62a115f93 100644 --- a/backend/src/v2/component/launcher_v2.go +++ b/backend/src/v2/component/launcher_v2.go @@ -56,6 +56,7 @@ type LauncherV2Options struct { MLPipelineTLSEnabled bool // set to true if metadata server is serving over tls MetadataTLSEnabled bool + CaCertPath string } type LauncherV2 struct { @@ -112,7 +113,7 @@ func NewLauncherV2(ctx context.Context, executionID int64, executorInputJSON, co if err != nil { return nil, fmt.Errorf("failed to initialize kubernetes client set: %w", err) } - metadataClient, err := metadata.NewClient(opts.MLMDServerAddress, opts.MLMDServerPort, opts.MetadataTLSEnabled) + metadataClient, err := metadata.NewClient(opts.MLMDServerAddress, opts.MLMDServerPort, opts.MetadataTLSEnabled, opts.CaCertPath) if err != nil { return nil, err } diff --git a/backend/src/v2/driver/driver.go b/backend/src/v2/driver/driver.go index bd20aba5126..02f7afcf40f 100644 --- a/backend/src/v2/driver/driver.go +++ b/backend/src/v2/driver/driver.go @@ -82,6 +82,8 @@ type Options struct { // set to true if MLMD server is serving over tls MLMDTLSEnabled bool + + CaCertPath string } // Identifying information used for error messages @@ -344,7 +346,7 @@ func Container(ctx context.Context, opts Options, mlmd *metadata.Client, cacheCl return execution, nil } - podSpec, err := initPodSpecPatch(opts.Container, opts.Component, executorInput, execution.ID, opts.PipelineName, opts.RunID, opts.MLPipelineTLSEnabled, opts.MLMDServerAddress, opts.MLMDServerPort, opts.MLMDTLSEnabled) + podSpec, err := initPodSpecPatch(opts.Container, opts.Component, executorInput, execution.ID, opts.PipelineName, opts.RunID, opts.MLPipelineTLSEnabled, opts.MLMDServerAddress, opts.MLMDServerPort, opts.MLMDTLSEnabled, opts.CaCertPath) if err != nil { return execution, err } @@ -381,6 +383,7 @@ func initPodSpecPatch( mlmdServerAddress string, mlmdServerPort string, mlmdTLSEnabled bool, + caCertPath string, ) (*k8score.PodSpec, error) { executorInputJSON, err := protojson.Marshal(executorInput) if err != nil { @@ -420,6 +423,7 @@ func initPodSpecPatch( "--metadataTLSEnabled", fmt.Sprintf("%v", mlmdTLSEnabled), "--mlPipelineServiceTLSEnabled", fmt.Sprintf("%v", mlPipelineTLSEnabled), + "--ca_cert_path", caCertPath, "--", // separater before user command and args } res := k8score.ResourceRequirements{ diff --git a/backend/src/v2/driver/driver_test.go b/backend/src/v2/driver/driver_test.go index 4c62d349708..5c7459f6dd3 100644 --- a/backend/src/v2/driver/driver_test.go +++ b/backend/src/v2/driver/driver_test.go @@ -244,7 +244,7 @@ func Test_initPodSpecPatch_acceleratorConfig(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false, "unused-mlmd-server-address", "unused-mlmd-server-port", false) + podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false, "unused-mlmd-server-address", "unused-mlmd-server-port", false, "unused-ca-cert-path") if tt.wantErr { assert.Nil(t, podSpec) assert.NotNil(t, err) @@ -406,7 +406,7 @@ func Test_initPodSpecPatch_resourceRequests(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false, "unused-mlmd-server-address", "unused-mlmd-server-port", false) + podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false, "unused-mlmd-server-address", "unused-mlmd-server-port", false, "unused-ca-cert-path") assert.Nil(t, err) assert.NotEmpty(t, podSpec) podSpecString, err := json.Marshal(podSpec) @@ -533,7 +533,7 @@ func Test_extendPodSpecPatch_Secret(t *testing.T) { { Name: "secret1", VolumeSource: k8score.VolumeSource{ - Secret: &k8score.SecretVolumeSource{SecretName: "secret1", Optional: &[]bool{false}[0],}, + Secret: &k8score.SecretVolumeSource{SecretName: "secret1", Optional: &[]bool{false}[0]}, }, }, }, diff --git a/backend/src/v2/metadata/client.go b/backend/src/v2/metadata/client.go index cf745d1f3ae..9f48f416795 100644 --- a/backend/src/v2/metadata/client.go +++ b/backend/src/v2/metadata/client.go @@ -19,9 +19,12 @@ package metadata import ( "context" "crypto/tls" + "crypto/x509" "encoding/json" "errors" "fmt" + "github.com/kubeflow/pipelines/backend/src/common/util" + "os" "path" "strconv" "strings" @@ -111,7 +114,7 @@ type Client struct { } // NewClient creates a Client given the MLMD server address and port. -func NewClient(serverAddress, serverPort string, tlsEnabled bool) (*Client, error) { +func NewClient(serverAddress, serverPort string, tlsEnabled bool, caCertPath string) (*Client, error) { opts := []grpc_retry.CallOption{ grpc_retry.WithMax(mlmdClientSideMaxRetries), grpc_retry.WithBackoff(grpc_retry.BackoffExponentialWithJitter(300*time.Millisecond, 0.20)), @@ -120,8 +123,19 @@ func NewClient(serverAddress, serverPort string, tlsEnabled bool) (*Client, erro creds := insecure.NewCredentials() if tlsEnabled { + if caCertPath == "" { + return nil, errors.New("CA cert path is empty") + } + + caCert, err := os.ReadFile(caCertPath) + if err != nil { + return nil, util.Wrap(err, "metadata.NewClient() failed") + } + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + config := &tls.Config{ - InsecureSkipVerify: true, // This should be removed by https://issues.redhat.com/browse/RHOAIENG-13871 + RootCAs: caCertPool, } creds = credentials.NewTLS(config) } diff --git a/backend/src/v2/metadata/client_test.go b/backend/src/v2/metadata/client_test.go index 8c9a0ef0c57..e0974cecbdd 100644 --- a/backend/src/v2/metadata/client_test.go +++ b/backend/src/v2/metadata/client_test.go @@ -84,7 +84,7 @@ func Test_GetPipeline(t *testing.T) { runUuid, err := uuid.NewRandom() fatalIf(err) runId := runUuid.String() - client, err := metadata.NewClient(testMlmdServerAddress, testMlmdServerPort, false) + client, err := metadata.NewClient(testMlmdServerAddress, testMlmdServerPort, false, "unused-ca-cert-path") fatalIf(err) mlmdClient, err := NewTestMlmdClient() fatalIf(err) @@ -135,7 +135,7 @@ func Test_GetPipeline_Twice(t *testing.T) { runUuid, err := uuid.NewRandom() fatalIf(err) runId := runUuid.String() - client, err := metadata.NewClient(testMlmdServerAddress, testMlmdServerPort, false) + client, err := metadata.NewClient(testMlmdServerAddress, testMlmdServerPort, false, "unused-ca-cert-path") fatalIf(err) pipeline, err := client.GetPipeline(ctx, "get-pipeline-test", runId, namespace, runResource, pipelineRoot, "") @@ -177,7 +177,7 @@ func Test_GetPipelineConcurrently(t *testing.T) { t.Skip("Temporarily disable the test that requires cluster connection.") // This test depends on a MLMD grpc server running at localhost:8080. - client, err := metadata.NewClient("localhost", "8080", false) + client, err := metadata.NewClient("localhost", "8080", false, "unused-ca-cert-path") if err != nil { t.Fatal(err) } @@ -335,7 +335,7 @@ func Test_DAG(t *testing.T) { func newLocalClientOrFatal(t *testing.T) *metadata.Client { t.Helper() - client, err := metadata.NewClient("localhost", "8080", false) + client, err := metadata.NewClient("localhost", "8080", false, "unused-ca-cert-path") if err != nil { t.Fatalf("metadata.NewClient failed: %v", err) }