Skip to content

Commit

Permalink
UPSTREAM: <carry>: Added support for TLS to MLMD GRPC Server
Browse files Browse the repository at this point in the history
Signed-off-by: hbelmiro <[email protected]>
  • Loading branch information
hbelmiro committed Sep 24, 2024
1 parent 3139121 commit ff444eb
Show file tree
Hide file tree
Showing 17 changed files with 111 additions and 22 deletions.
3 changes: 2 additions & 1 deletion backend/src/apiserver/client_manager/client_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,8 @@ func (c *ClientManager) init() {

c.k8sCoreClient = client.CreateKubernetesCoreOrFatal(common.GetDurationConfig(initConnectionTimeout), clientParams)

newClient, err := metadata.NewClient(common.GetMetadataGrpcServiceServiceHost(), common.GetMetadataGrpcServiceServicePort())
newClient, err := metadata.NewClient(common.GetMetadataGrpcServiceServiceHost(), common.GetMetadataGrpcServiceServicePort(), common.GetMetadataTLSEnabled())

if err != nil {
glog.Fatalf("Failed to create metadata client. Error: %v", err)
}
Expand Down
5 changes: 5 additions & 0 deletions backend/src/apiserver/common/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ const (
TokenReviewAudience string = "TOKEN_REVIEW_AUDIENCE"
MetadataGrpcServiceServiceHost string = "METADATA_GRPC_SERVICE_SERVICE_HOST"
MetadataGrpcServiceServicePort string = "METADATA_GRPC_SERVICE_SERVICE_PORT"
MetadataTLSEnabled string = "METADATA_TLS_ENABLED"
SignedURLExpiryTimeSeconds string = "SIGNED_URL_EXPIRY_TIME_SECONDS"
)

Expand Down Expand Up @@ -142,3 +143,7 @@ func GetMetadataGrpcServiceServicePort() string {
func GetSignedURLExpiryTimeSeconds() int {
return GetIntConfigWithDefault(SignedURLExpiryTimeSeconds, DefaultSignedURLExpiryTimeSeconds)
}

func GetMetadataTLSEnabled() bool {
return GetBoolConfigWithDefault(MetadataTLSEnabled, DefaultMetadataTLSEnabled)
}
1 change: 1 addition & 0 deletions backend/src/apiserver/common/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ const DefaultTokenReviewAudience string = "pipelines.kubeflow.org"
const (
DefaultMetadataGrpcServiceServiceHost = "metadata-grpc-service"
DefaultMetadataGrpcServiceServicePort = "8080"
DefaultMetadataTLSEnabled = false
)

const (
Expand Down
17 changes: 16 additions & 1 deletion backend/src/v2/cmd/driver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ var (
conditionPath = flag.String("condition_path", "", "Condition output path")

mlPipelineServiceTLSEnabledStr = flag.String("mlPipelineServiceTLSEnabled", "false", "Set to 'true' if mlpipeline api server serves over TLS (default: 'false').")
metadataTLSEnabledStr = flag.String("metadataTLSEnabled", "false", "Set to 'true' if metadata server serves over TLS (default: 'false').")
)

// func RootDAG(pipelineName string, runID string, component *pipelinespec.ComponentSpec, task *pipelinespec.PipelineTaskSpec, mlmd *metadata.Client) (*Execution, error) {
Expand Down Expand Up @@ -154,6 +155,11 @@ func drive() (err error) {
return err
}

metadataTLSEnabled, err := strconv.ParseBool(*metadataTLSEnabledStr)
if err != nil {
return err
}

cacheClient, err := cacheutils.NewClient(mlPipelineServiceTLSEnabled)
if err != nil {
return err
Expand All @@ -167,6 +173,9 @@ func drive() (err error) {
DAGExecutionID: *dagExecutionID,
IterationIndex: *iterationIndex,
MLPipelineTLSEnabled: mlPipelineServiceTLSEnabled,
MLMDServerAddress: *mlmdServerAddress,
MLMDServerPort: *mlmdServerPort,
MLMDTLSEnabled: metadataTLSEnabled,
}
var execution *driver.Execution
var driverErr error
Expand Down Expand Up @@ -292,5 +301,11 @@ func newMlmdClient() (*metadata.Client, error) {
mlmdConfig.Address = *mlmdServerAddress
mlmdConfig.Port = *mlmdServerPort
}
return metadata.NewClient(mlmdConfig.Address, mlmdConfig.Port)

tlsEnabled, err := strconv.ParseBool(*metadataTLSEnabledStr)
if err != nil {
return nil, err
}

return metadata.NewClient(mlmdConfig.Address, mlmdConfig.Port, tlsEnabled)
}
8 changes: 8 additions & 0 deletions backend/src/v2/cmd/launcher-v2/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ var (
mlmdServerAddress = flag.String("mlmd_server_address", "", "The MLMD gRPC server address.")
mlmdServerPort = flag.String("mlmd_server_port", "8080", "The MLMD gRPC server port.")
mlPipelineServiceTLSEnabledStr = flag.String("mlPipelineServiceTLSEnabled", "false", "Set to 'true' if mlpipeline api server serves over TLS (default: 'false').")
metadataTLSEnabledStr = flag.String("metadataTLSEnabled", "false", "Set to 'true' if metadata server serves over TLS (default: 'false').")
)

func main() {
Expand Down Expand Up @@ -71,6 +72,12 @@ func run() error {
if err != nil {
return err
}

metadataServiceTLSEnabled, err := strconv.ParseBool(*metadataTLSEnabledStr)
if err != nil {
return err
}

launcherV2Opts := &component.LauncherV2Options{
Namespace: namespace,
PodName: *podName,
Expand All @@ -80,6 +87,7 @@ func run() error {
PipelineName: *pipelineName,
RunID: *runID,
MLPipelineTLSEnabled: mlPipelineServiceTLSEnabled,
MetadataTLSEnabled: metadataServiceTLSEnabled,
}

switch *executorType {
Expand Down
4 changes: 4 additions & 0 deletions backend/src/v2/compiler/argocompiler/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package argocompiler
import (
wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1"
"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
"github.com/kubeflow/pipelines/backend/src/apiserver/common"
"github.com/kubeflow/pipelines/backend/src/v2/component"
k8score "k8s.io/api/core/v1"
"os"
Expand Down Expand Up @@ -163,6 +164,9 @@ func (c *workflowCompiler) addContainerDriverTemplate() string {
"--condition_path", outputPath(paramCondition),
"--kubernetes_config", inputValue(paramKubernetesConfig),
"--mlPipelineServiceTLSEnabled", strconv.FormatBool(c.mlPipelineServiceTLSEnabled),
"--mlmd_server_address", common.GetMetadataGrpcServiceServiceHost(),
"--mlmd_server_port", common.GetMetadataGrpcServiceServicePort(),
"--metadataTLSEnabled", strconv.FormatBool(common.GetMetadataTLSEnabled()),
},
Resources: driverResources,
},
Expand Down
4 changes: 4 additions & 0 deletions backend/src/v2/compiler/argocompiler/dag.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package argocompiler

import (
"fmt"
"github.com/kubeflow/pipelines/backend/src/apiserver/common"
"sort"
"strconv"
"strings"
Expand Down Expand Up @@ -443,6 +444,9 @@ func (c *workflowCompiler) addDAGDriverTemplate() string {
"--iteration_count_path", outputPath(paramIterationCount),
"--condition_path", outputPath(paramCondition),
"--mlPipelineServiceTLSEnabled", strconv.FormatBool(c.mlPipelineServiceTLSEnabled),
"--mlmd_server_address", common.GetMetadataGrpcServiceServiceHost(),
"--mlmd_server_port", common.GetMetadataGrpcServiceServicePort(),
"--metadataTLSEnabled", strconv.FormatBool(common.GetMetadataTLSEnabled()),
},
Resources: driverResources,
},
Expand Down
7 changes: 3 additions & 4 deletions backend/src/v2/compiler/argocompiler/importer.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package argocompiler

import (
"fmt"
"github.com/kubeflow/pipelines/backend/src/apiserver/common"

wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1"
"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
Expand Down Expand Up @@ -76,10 +77,8 @@ func (c *workflowCompiler) addImporterTemplate() string {
fmt.Sprintf("$(%s)", component.EnvPodName),
"--pod_uid",
fmt.Sprintf("$(%s)", component.EnvPodUID),
"--mlmd_server_address",
fmt.Sprintf("$(%s)", component.EnvMetadataHost),
"--mlmd_server_port",
fmt.Sprintf("$(%s)", component.EnvMetadataPort),
"--mlmd_server_address", common.GetMetadataGrpcServiceServiceHost(),
"--mlmd_server_port", common.GetMetadataGrpcServiceServicePort(),
}
importerTemplate := &wfapi.Template{
Name: name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ spec:
- '{{inputs.parameters.kubernetes-config}}'
- "--mlPipelineServiceTLSEnabled"
- "false"
- "--mlmd_server_address"
- "metadata-grpc-service"
- "--mlmd_server_port"
- "8080"
- "--metadataTLSEnabled"
- "false"
command:
- driver
image: gcr.io/ml-pipeline/kfp-driver
Expand Down Expand Up @@ -312,6 +318,12 @@ spec:
- '{{outputs.parameters.condition.path}}'
- "--mlPipelineServiceTLSEnabled"
- "false"
- "--mlmd_server_address"
- "metadata-grpc-service"
- "--mlmd_server_port"
- "8080"
- "--metadataTLSEnabled"
- "false"
command:
- driver
image: gcr.io/ml-pipeline/kfp-driver
Expand Down
12 changes: 12 additions & 0 deletions backend/src/v2/compiler/argocompiler/testdata/hello_world.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ spec:
- '{{inputs.parameters.kubernetes-config}}'
- "--mlPipelineServiceTLSEnabled"
- "false"
- "--mlmd_server_address"
- "metadata-grpc-service"
- "--mlmd_server_port"
- "8080"
- "--metadataTLSEnabled"
- "false"
command:
- driver
image: gcr.io/ml-pipeline/kfp-driver
Expand Down Expand Up @@ -242,6 +248,12 @@ spec:
- '{{outputs.parameters.condition.path}}'
- "--mlPipelineServiceTLSEnabled"
- "false"
- "--mlmd_server_address"
- "metadata-grpc-service"
- "--mlmd_server_port"
- "8080"
- "--metadataTLSEnabled"
- "false"
env:
- name: ML_PIPELINE_SERVICE_HOST
value: ml-pipeline.kubeflow.svc.cluster.local
Expand Down
10 changes: 8 additions & 2 deletions backend/src/v2/compiler/argocompiler/testdata/importer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ spec:
- --pod_uid
- $(KFP_POD_UID)
- --mlmd_server_address
- $(METADATA_GRPC_SERVICE_HOST)
- "metadata-grpc-service"
- --mlmd_server_port
- $(METADATA_GRPC_SERVICE_PORT)
- "8080"
command:
- launcher-v2
env:
Expand Down Expand Up @@ -120,6 +120,12 @@ spec:
- '{{outputs.parameters.condition.path}}'
- "--mlPipelineServiceTLSEnabled"
- "false"
- "--mlmd_server_address"
- "metadata-grpc-service"
- "--mlmd_server_port"
- "8080"
- "--metadataTLSEnabled"
- "false"
command:
- driver
image: gcr.io/ml-pipeline/kfp-driver
Expand Down
2 changes: 1 addition & 1 deletion backend/src/v2/component/importer_launcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func NewImporterLauncher(ctx context.Context, componentSpecJSON, importerSpecJSO
if err != nil {
return nil, fmt.Errorf("failed to initialize kubernetes client set: %w", err)
}
metadataClient, err := metadata.NewClient(launcherV2Opts.MLMDServerAddress, launcherV2Opts.MLMDServerPort)
metadataClient, err := metadata.NewClient(launcherV2Opts.MLMDServerAddress, launcherV2Opts.MLMDServerPort, launcherV2Opts.MetadataTLSEnabled)
if err != nil {
return nil, err
}
Expand Down
4 changes: 3 additions & 1 deletion backend/src/v2/component/launcher_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ type LauncherV2Options struct {
RunID string
// set to true if ml pipeline server is serving over tls
MLPipelineTLSEnabled bool
// set to true if metadata server is serving over tls
MetadataTLSEnabled bool
}

type LauncherV2 struct {
Expand Down Expand Up @@ -110,7 +112,7 @@ func NewLauncherV2(ctx context.Context, executionID int64, executorInputJSON, co
if err != nil {
return nil, fmt.Errorf("failed to initialize kubernetes client set: %w", err)
}
metadataClient, err := metadata.NewClient(opts.MLMDServerAddress, opts.MLMDServerPort)
metadataClient, err := metadata.NewClient(opts.MLMDServerAddress, opts.MLMDServerPort, opts.MetadataTLSEnabled)
if err != nil {
return nil, err
}
Expand Down
19 changes: 14 additions & 5 deletions backend/src/v2/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ type Options struct {

// set to true if ml pipeline server is serving over tls
MLPipelineTLSEnabled bool

MLMDServerAddress string

MLMDServerPort string

// set to true if MLMD server is serving over tls
MLMDTLSEnabled bool
}

// Identifying information used for error messages
Expand Down Expand Up @@ -339,7 +346,7 @@ func Container(ctx context.Context, opts Options, mlmd *metadata.Client, cacheCl
return execution, nil
}

podSpec, err := initPodSpecPatch(opts.Container, opts.Component, executorInput, execution.ID, opts.PipelineName, opts.RunID, opts.MLPipelineTLSEnabled)
podSpec, err := initPodSpecPatch(opts.Container, opts.Component, executorInput, execution.ID, opts.PipelineName, opts.RunID, opts.MLPipelineTLSEnabled, opts.MLMDServerAddress, opts.MLMDServerPort, opts.MLMDTLSEnabled)
if err != nil {
return execution, err
}
Expand Down Expand Up @@ -373,6 +380,9 @@ func initPodSpecPatch(
pipelineName string,
runID string,
mlPipelineTLSEnabled bool,
mlmdServerAddress string,
mlmdServerPort string,
mlmdTLSEnabled bool,
) (*k8score.PodSpec, error) {
executorInputJSON, err := protojson.Marshal(executorInput)
if err != nil {
Expand Down Expand Up @@ -407,10 +417,9 @@ func initPodSpecPatch(
fmt.Sprintf("$(%s)", component.EnvPodName),
"--pod_uid",
fmt.Sprintf("$(%s)", component.EnvPodUID),
"--mlmd_server_address",
fmt.Sprintf("$(%s)", component.EnvMetadataHost),
"--mlmd_server_port",
fmt.Sprintf("$(%s)", component.EnvMetadataPort),
"--mlmd_server_address", mlmdServerAddress,
"--mlmd_server_port", mlmdServerPort,
"--metadataTLSEnabled", fmt.Sprintf("%v", mlmdTLSEnabled),
"--mlPipelineServiceTLSEnabled",
fmt.Sprintf("%v", mlPipelineTLSEnabled),
"--", // separater before user command and args
Expand Down
4 changes: 2 additions & 2 deletions backend/src/v2/driver/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ func Test_initPodSpecPatch_acceleratorConfig(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false)
podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false, "unused-mlmd-server-address", "unused-mlmd-server-port", false)
if tt.wantErr {
assert.Nil(t, podSpec)
assert.NotNil(t, err)
Expand Down Expand Up @@ -403,7 +403,7 @@ func Test_initPodSpecPatch_resourceRequests(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false)
podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false, "unused-mlmd-server-address", "unused-mlmd-server-port", false)
assert.Nil(t, err)
assert.NotEmpty(t, podSpec)
podSpecString, err := json.Marshal(podSpec)
Expand Down
13 changes: 12 additions & 1 deletion backend/src/v2/metadata/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ package metadata

import (
"context"
"crypto/tls"
"errors"
"fmt"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"path"
"strconv"
"strings"
Expand Down Expand Up @@ -105,14 +108,22 @@ type Client struct {
}

// NewClient creates a Client given the MLMD server address and port.
func NewClient(serverAddress, serverPort string) (*Client, error) {
func NewClient(serverAddress, serverPort string, tlsEnabled bool) (*Client, error) {
opts := []grpc_retry.CallOption{
grpc_retry.WithMax(mlmdClientSideMaxRetries),
grpc_retry.WithBackoff(grpc_retry.BackoffExponentialWithJitter(300*time.Millisecond, 0.20)),
grpc_retry.WithCodes(codes.Aborted),
}

creds := insecure.NewCredentials()
if tlsEnabled {
config := &tls.Config{}
creds = credentials.NewTLS(config)
}

conn, err := grpc.Dial(fmt.Sprintf("%s:%s", serverAddress, serverPort),
grpc.WithInsecure(),
grpc.WithTransportCredentials(creds),
grpc.WithStreamInterceptor(grpc_retry.StreamClientInterceptor(opts...)),
grpc.WithUnaryInterceptor(grpc_retry.UnaryClientInterceptor(opts...)),
)
Expand Down
Loading

0 comments on commit ff444eb

Please sign in to comment.