Skip to content

Commit

Permalink
UPSTREAM: <carry>: Use DSPA custom ca cert on MLMD and Persistence Ag…
Browse files Browse the repository at this point in the history
…ent clients

Signed-off-by: Helber Belmiro <[email protected]>
  • Loading branch information
hbelmiro committed Oct 31, 2024
1 parent 57a2823 commit bef5fa2
Show file tree
Hide file tree
Showing 21 changed files with 88 additions and 20 deletions.
5 changes: 3 additions & 2 deletions backend/src/agent/persistence/client/pipeline_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ func NewPipelineClient(
mlPipelineServiceName string,
mlPipelineServiceHttpPort string,
mlPipelineServiceGRPCPort string,
mlPipelineServiceTLSEnabled bool) (*PipelineClient, error) {
mlPipelineServiceTLSEnabled bool,
caCertPath string) (*PipelineClient, error) {
httpAddress := fmt.Sprintf(addressTemp, mlPipelineServiceName, mlPipelineServiceHttpPort)
grpcAddress := fmt.Sprintf(addressTemp, mlPipelineServiceName, mlPipelineServiceGRPCPort)
scheme := "http"
Expand All @@ -68,7 +69,7 @@ func NewPipelineClient(
return nil, errors.Wrapf(err,
"Failed to initialize pipeline client. Error: %s", err.Error())
}
connection, err := util.GetRpcConnection(grpcAddress, mlPipelineServiceTLSEnabled)
connection, err := util.GetRpcConnection(grpcAddress, mlPipelineServiceTLSEnabled, caCertPath)
if err != nil {
return nil, errors.Wrapf(err,
"Failed to get RPC connection. Error: %s", err.Error())
Expand Down
7 changes: 5 additions & 2 deletions backend/src/agent/persistence/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ var (
clientBurst int
executionType string
saTokenRefreshIntervalInSecs int64
caCertPath string
)

const (
Expand All @@ -68,6 +69,7 @@ const (
clientBurstFlagName = "clientBurst"
executionTypeFlagName = "executionType"
saTokenRefreshIntervalFlagName = "saTokenRefreshIntervalInSecs"
caCertPathFlagName = "caCertPath"
)

const (
Expand Down Expand Up @@ -135,7 +137,8 @@ func main() {
mlPipelineAPIServerName,
mlPipelineServiceHttpPort,
mlPipelineServiceGRPCPort,
mlPipelineServiceTLSEnabled)
mlPipelineServiceTLSEnabled,
caCertPath)
if err != nil {
log.Fatalf("Error creating ML pipeline API Server client: %v", err)
}
Expand Down Expand Up @@ -177,5 +180,5 @@ func init() {
// TODO use viper/config file instead. Sync `saTokenRefreshIntervalFlagName` with the value from manifest file by using ENV var.
flag.Int64Var(&saTokenRefreshIntervalInSecs, saTokenRefreshIntervalFlagName, DefaultSATokenRefresherIntervalInSecs, "Persistence agent service account token read interval in seconds. "+
"Defines how often `/var/run/secrets/kubeflow/tokens/kubeflow-persistent_agent-api-token` to be read")

flag.StringVar(&caCertPath, caCertPathFlagName, "", "The path to the CA certificate.")
}
2 changes: 1 addition & 1 deletion backend/src/apiserver/client_manager/client_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ func (c *ClientManager) init() {

c.k8sCoreClient = client.CreateKubernetesCoreOrFatal(common.GetDurationConfig(initConnectionTimeout), clientParams)

newClient, err := metadata.NewClient(common.GetMetadataGrpcServiceServiceHost(), common.GetMetadataGrpcServiceServicePort(), common.GetMetadataTLSEnabled())
newClient, err := metadata.NewClient(common.GetMetadataGrpcServiceServiceHost(), common.GetMetadataGrpcServiceServicePort(), common.GetMetadataTLSEnabled(), common.GetCaCertPath())

if err != nil {
glog.Fatalf("Failed to create metadata client. Error: %v", err)
Expand Down
5 changes: 5 additions & 0 deletions backend/src/apiserver/common/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ const (
MetadataGrpcServiceServicePort string = "METADATA_GRPC_SERVICE_SERVICE_PORT"
MetadataTLSEnabled string = "METADATA_TLS_ENABLED"
SignedURLExpiryTimeSeconds string = "SIGNED_URL_EXPIRY_TIME_SECONDS"
CaCertPath string = "CA_CERT_PATH"
)

func IsPipelineVersionUpdatedByDefault() bool {
Expand Down Expand Up @@ -147,3 +148,7 @@ func GetSignedURLExpiryTimeSeconds() int {
func GetMetadataTLSEnabled() bool {
return GetBoolConfigWithDefault(MetadataTLSEnabled, DefaultMetadataTLSEnabled)
}

func GetCaCertPath() string {
return GetStringConfigWithDefault(CaCertPath, DefaultCaCertPath)
}
2 changes: 2 additions & 0 deletions backend/src/apiserver/common/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,5 @@ const (
MaxFileNameLength = 100
MaxFileLength = 32 << 20 // 32Mb
)

const DefaultCaCertPath = ""
19 changes: 17 additions & 2 deletions backend/src/common/util/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ package util

import (
"crypto/tls"
"crypto/x509"
"fmt"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"net/http"
"os"
"strings"
"time"

Expand Down Expand Up @@ -77,10 +79,23 @@ func GetKubernetesClientFromClientConfig(clientConfig clientcmd.ClientConfig) (
return clientSet, config, namespace, nil
}

func GetRpcConnection(address string, tlsEnabled bool) (*grpc.ClientConn, error) {
func GetRpcConnection(address string, tlsEnabled bool, caCertPath string) (*grpc.ClientConn, error) {
creds := insecure.NewCredentials()
if tlsEnabled {
config := &tls.Config{}
if caCertPath == "" {
return nil, errors.New("CA cert path is empty")
}

caCert, err := os.ReadFile(caCertPath)
if err != nil {
return nil, errors.Wrap(err, "metadata.NewClient() failed")
}
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)

config := &tls.Config{
RootCAs: caCertPool,
}
creds = credentials.NewTLS(config)
}

Expand Down
4 changes: 3 additions & 1 deletion backend/src/v2/cmd/driver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ var (

mlPipelineServiceTLSEnabledStr = flag.String("mlPipelineServiceTLSEnabled", "false", "Set to 'true' if mlpipeline api server serves over TLS (default: 'false').")
metadataTLSEnabledStr = flag.String("metadataTLSEnabled", "false", "Set to 'true' if metadata server serves over TLS (default: 'false').")
caCertPath = flag.String("ca_cert_path", "", "The path to the CA certificate.")
)

// func RootDAG(pipelineName string, runID string, component *pipelinespec.ComponentSpec, task *pipelinespec.PipelineTaskSpec, mlmd *metadata.Client) (*Execution, error) {
Expand Down Expand Up @@ -176,6 +177,7 @@ func drive() (err error) {
MLMDServerAddress: *mlmdServerAddress,
MLMDServerPort: *mlmdServerPort,
MLMDTLSEnabled: metadataTLSEnabled,
CaCertPath: *caCertPath,
}
var execution *driver.Execution
var driverErr error
Expand Down Expand Up @@ -307,5 +309,5 @@ func newMlmdClient() (*metadata.Client, error) {
return nil, err
}

return metadata.NewClient(mlmdConfig.Address, mlmdConfig.Port, tlsEnabled)
return metadata.NewClient(mlmdConfig.Address, mlmdConfig.Port, tlsEnabled, *caCertPath)
}
2 changes: 2 additions & 0 deletions backend/src/v2/cmd/launcher-v2/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ var (
mlmdServerPort = flag.String("mlmd_server_port", "8080", "The MLMD gRPC server port.")
mlPipelineServiceTLSEnabledStr = flag.String("mlPipelineServiceTLSEnabled", "false", "Set to 'true' if mlpipeline api server serves over TLS (default: 'false').")
metadataTLSEnabledStr = flag.String("metadataTLSEnabled", "false", "Set to 'true' if metadata server serves over TLS (default: 'false').")
caCertPath = flag.String("ca_cert_path", "", "The path to the CA certificate.")
)

func main() {
Expand Down Expand Up @@ -88,6 +89,7 @@ func run() error {
RunID: *runID,
MLPipelineTLSEnabled: mlPipelineServiceTLSEnabled,
MetadataTLSEnabled: metadataServiceTLSEnabled,
CaCertPath: *caCertPath,
}

switch *executorType {
Expand Down
1 change: 1 addition & 0 deletions backend/src/v2/compiler/argocompiler/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ func (c *workflowCompiler) addContainerDriverTemplate() string {
"--mlmd_server_address", common.GetMetadataGrpcServiceServiceHost(),
"--mlmd_server_port", common.GetMetadataGrpcServiceServicePort(),
"--metadataTLSEnabled", strconv.FormatBool(common.GetMetadataTLSEnabled()),
"--ca_cert_path", common.GetCaCertPath(),
},
Resources: driverResources,
},
Expand Down
1 change: 1 addition & 0 deletions backend/src/v2/compiler/argocompiler/dag.go
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,7 @@ func (c *workflowCompiler) addDAGDriverTemplate() string {
"--mlmd_server_address", common.GetMetadataGrpcServiceServiceHost(),
"--mlmd_server_port", common.GetMetadataGrpcServiceServicePort(),
"--metadataTLSEnabled", strconv.FormatBool(common.GetMetadataTLSEnabled()),
"--ca_cert_path", common.GetCaCertPath(),
},
Resources: driverResources,
},
Expand Down
1 change: 1 addition & 0 deletions backend/src/v2/compiler/argocompiler/importer.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ func (c *workflowCompiler) addImporterTemplate() string {
"--mlmd_server_address", common.GetMetadataGrpcServiceServiceHost(),
"--mlmd_server_port", common.GetMetadataGrpcServiceServicePort(),
"--metadataTLSEnabled", strconv.FormatBool(common.GetMetadataTLSEnabled()),
"--ca_cert_path", common.GetCaCertPath(),
}
importerTemplate := &wfapi.Template{
Name: name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ spec:
- "8080"
- "--metadataTLSEnabled"
- "false"
- "--ca_cert_path"
- ""
command:
- driver
image: gcr.io/ml-pipeline/kfp-driver
Expand Down Expand Up @@ -324,6 +326,8 @@ spec:
- "8080"
- "--metadataTLSEnabled"
- "false"
- "--ca_cert_path"
- ""
command:
- driver
image: gcr.io/ml-pipeline/kfp-driver
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ spec:
- "8080"
- "--metadataTLSEnabled"
- "false"
- "--ca_cert_path"
- ""
command:
- driver
image: gcr.io/ml-pipeline/kfp-driver
Expand Down Expand Up @@ -263,6 +265,8 @@ spec:
- "8080"
- "--metadataTLSEnabled"
- "false"
- "--ca_cert_path"
- ""
command:
- driver
image: gcr.io/ml-pipeline/kfp-driver
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ spec:
- "8080"
- "--metadataTLSEnabled"
- "false"
- "--ca_cert_path"
- ""
command:
- driver
image: gcr.io/ml-pipeline/kfp-driver
Expand Down Expand Up @@ -254,6 +256,8 @@ spec:
- "8080"
- "--metadataTLSEnabled"
- "false"
- "--ca_cert_path"
- ""
env:
- name: ML_PIPELINE_SERVICE_HOST
value: ml-pipeline.kubeflow.svc.cluster.local
Expand Down
4 changes: 4 additions & 0 deletions backend/src/v2/compiler/argocompiler/testdata/importer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ spec:
- "8080"
- --metadataTLSEnabled
- "false"
- "--ca_cert_path"
- ""
command:
- launcher-v2
env:
Expand Down Expand Up @@ -128,6 +130,8 @@ spec:
- "8080"
- "--metadataTLSEnabled"
- "false"
- "--ca_cert_path"
- ""
command:
- driver
image: gcr.io/ml-pipeline/kfp-driver
Expand Down
2 changes: 1 addition & 1 deletion backend/src/v2/component/importer_launcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func NewImporterLauncher(ctx context.Context, componentSpecJSON, importerSpecJSO
if err != nil {
return nil, fmt.Errorf("failed to initialize kubernetes client set: %w", err)
}
metadataClient, err := metadata.NewClient(launcherV2Opts.MLMDServerAddress, launcherV2Opts.MLMDServerPort, launcherV2Opts.MetadataTLSEnabled)
metadataClient, err := metadata.NewClient(launcherV2Opts.MLMDServerAddress, launcherV2Opts.MLMDServerPort, launcherV2Opts.MetadataTLSEnabled, launcherV2Opts.CaCertPath)
if err != nil {
return nil, err
}
Expand Down
3 changes: 2 additions & 1 deletion backend/src/v2/component/launcher_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ type LauncherV2Options struct {
MLPipelineTLSEnabled bool
// set to true if metadata server is serving over tls
MetadataTLSEnabled bool
CaCertPath string
}

type LauncherV2 struct {
Expand Down Expand Up @@ -112,7 +113,7 @@ func NewLauncherV2(ctx context.Context, executionID int64, executorInputJSON, co
if err != nil {
return nil, fmt.Errorf("failed to initialize kubernetes client set: %w", err)
}
metadataClient, err := metadata.NewClient(opts.MLMDServerAddress, opts.MLMDServerPort, opts.MetadataTLSEnabled)
metadataClient, err := metadata.NewClient(opts.MLMDServerAddress, opts.MLMDServerPort, opts.MetadataTLSEnabled, opts.CaCertPath)
if err != nil {
return nil, err
}
Expand Down
6 changes: 5 additions & 1 deletion backend/src/v2/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ type Options struct {

// set to true if MLMD server is serving over tls
MLMDTLSEnabled bool

CaCertPath string
}

// Identifying information used for error messages
Expand Down Expand Up @@ -344,7 +346,7 @@ func Container(ctx context.Context, opts Options, mlmd *metadata.Client, cacheCl
return execution, nil
}

podSpec, err := initPodSpecPatch(opts.Container, opts.Component, executorInput, execution.ID, opts.PipelineName, opts.RunID, opts.MLPipelineTLSEnabled, opts.MLMDServerAddress, opts.MLMDServerPort, opts.MLMDTLSEnabled)
podSpec, err := initPodSpecPatch(opts.Container, opts.Component, executorInput, execution.ID, opts.PipelineName, opts.RunID, opts.MLPipelineTLSEnabled, opts.MLMDServerAddress, opts.MLMDServerPort, opts.MLMDTLSEnabled, opts.CaCertPath)
if err != nil {
return execution, err
}
Expand Down Expand Up @@ -381,6 +383,7 @@ func initPodSpecPatch(
mlmdServerAddress string,
mlmdServerPort string,
mlmdTLSEnabled bool,
caCertPath string,
) (*k8score.PodSpec, error) {
executorInputJSON, err := protojson.Marshal(executorInput)
if err != nil {
Expand Down Expand Up @@ -420,6 +423,7 @@ func initPodSpecPatch(
"--metadataTLSEnabled", fmt.Sprintf("%v", mlmdTLSEnabled),
"--mlPipelineServiceTLSEnabled",
fmt.Sprintf("%v", mlPipelineTLSEnabled),
"--ca_cert_path", caCertPath,
"--", // separater before user command and args
}
res := k8score.ResourceRequirements{
Expand Down
6 changes: 3 additions & 3 deletions backend/src/v2/driver/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ func Test_initPodSpecPatch_acceleratorConfig(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false, "unused-mlmd-server-address", "unused-mlmd-server-port", false)
podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false, "unused-mlmd-server-address", "unused-mlmd-server-port", false, "unused-ca-cert-path")
if tt.wantErr {
assert.Nil(t, podSpec)
assert.NotNil(t, err)
Expand Down Expand Up @@ -406,7 +406,7 @@ func Test_initPodSpecPatch_resourceRequests(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false, "unused-mlmd-server-address", "unused-mlmd-server-port", false)
podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false, "unused-mlmd-server-address", "unused-mlmd-server-port", false, "unused-ca-cert-path")
assert.Nil(t, err)
assert.NotEmpty(t, podSpec)
podSpecString, err := json.Marshal(podSpec)
Expand Down Expand Up @@ -533,7 +533,7 @@ func Test_extendPodSpecPatch_Secret(t *testing.T) {
{
Name: "secret1",
VolumeSource: k8score.VolumeSource{
Secret: &k8score.SecretVolumeSource{SecretName: "secret1", Optional: &[]bool{false}[0],},
Secret: &k8score.SecretVolumeSource{SecretName: "secret1", Optional: &[]bool{false}[0]},
},
},
},
Expand Down
18 changes: 16 additions & 2 deletions backend/src/v2/metadata/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@ package metadata
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"github.com/kubeflow/pipelines/backend/src/common/util"
"os"
"path"
"strconv"
"strings"
Expand Down Expand Up @@ -111,7 +114,7 @@ type Client struct {
}

// NewClient creates a Client given the MLMD server address and port.
func NewClient(serverAddress, serverPort string, tlsEnabled bool) (*Client, error) {
func NewClient(serverAddress, serverPort string, tlsEnabled bool, caCertPath string) (*Client, error) {
opts := []grpc_retry.CallOption{
grpc_retry.WithMax(mlmdClientSideMaxRetries),
grpc_retry.WithBackoff(grpc_retry.BackoffExponentialWithJitter(300*time.Millisecond, 0.20)),
Expand All @@ -120,8 +123,19 @@ func NewClient(serverAddress, serverPort string, tlsEnabled bool) (*Client, erro

creds := insecure.NewCredentials()
if tlsEnabled {
if caCertPath == "" {
return nil, errors.New("CA cert path is empty")
}

caCert, err := os.ReadFile(caCertPath)
if err != nil {
return nil, util.Wrap(err, "metadata.NewClient() failed")
}
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)

config := &tls.Config{
InsecureSkipVerify: true, // This should be removed by https://issues.redhat.com/browse/RHOAIENG-13871
RootCAs: caCertPool,
}
creds = credentials.NewTLS(config)
}
Expand Down
Loading

0 comments on commit bef5fa2

Please sign in to comment.